mirror of
https://github.com/sd-webui/stable-diffusion-webui.git
synced 2024-12-16 08:22:33 +03:00
commit
f4493efe11
3
.dockerignore
Normal file
3
.dockerignore
Normal file
@ -0,0 +1,3 @@
|
||||
models/
|
||||
outputs/
|
||||
src/
|
@ -6,9 +6,13 @@ CONDA_FORCE_UPDATE=false
|
||||
# (useful to set to false after you're sure the model files are already in place)
|
||||
VALIDATE_MODELS=true
|
||||
|
||||
#Automatically relaunch the webui on crashes
|
||||
# Automatically relaunch the webui on crashes
|
||||
WEBUI_RELAUNCH=true
|
||||
|
||||
#Pass cli arguments to webui.py e.g:
|
||||
#WEBUI_ARGS=--gpu=1 --esrgan-gpu=1 --gfpgan-gpu=1
|
||||
# Which webui to launch
|
||||
# WEBUI_SCRIPT=webui_streamlit.py
|
||||
WEBUI_SCRIPT=webui.py
|
||||
|
||||
# Pass cli arguments to webui.py e.g:
|
||||
# WEBUI_ARGS=--optimized --extra-models-cpu --gpu=1 --esrgan-gpu=1 --gfpgan-gpu=1
|
||||
WEBUI_ARGS=
|
||||
|
7
.gitignore
vendored
7
.gitignore
vendored
@ -47,13 +47,18 @@ MANIFEST
|
||||
.env_updated
|
||||
condaenv.*.requirements.txt
|
||||
|
||||
# Visual Studio directories
|
||||
.vs/
|
||||
.vscode/
|
||||
|
||||
# =========================================================================== #
|
||||
# Repo-specific
|
||||
# =========================================================================== #
|
||||
/configs/webui/userconfig_streamlit.yaml
|
||||
/custom-conda-path.txt
|
||||
/src/*
|
||||
/outputs/*
|
||||
/outputs
|
||||
/model_cache
|
||||
/log/**/*.png
|
||||
/log/log.csv
|
||||
/flagged/*
|
||||
|
@ -1,6 +1,10 @@
|
||||
FROM nvidia/cuda:11.3.1-runtime-ubuntu20.04
|
||||
|
||||
ENV DEBIAN_FRONTEND=noninteractive
|
||||
ENV DEBIAN_FRONTEND=noninteractive \
|
||||
PYTHONUNBUFFERED=1 \
|
||||
PYTHONIOENCODING=UTF-8 \
|
||||
CONDA_DIR=/opt/conda
|
||||
|
||||
WORKDIR /sd
|
||||
|
||||
SHELL ["/bin/bash", "-c"]
|
||||
@ -11,7 +15,6 @@ RUN apt-get update && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Install miniconda
|
||||
ENV CONDA_DIR /opt/conda
|
||||
RUN wget -O ~/miniconda.sh -q --show-progress --progress=bar:force https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh && \
|
||||
/bin/bash ~/miniconda.sh -b -p $CONDA_DIR && \
|
||||
rm ~/miniconda.sh
|
||||
@ -20,7 +23,7 @@ ENV PATH=$CONDA_DIR/bin:$PATH
|
||||
# Install font for prompt matrix
|
||||
COPY /data/DejaVuSans.ttf /usr/share/fonts/truetype/
|
||||
|
||||
EXPOSE 7860
|
||||
EXPOSE 7860 8501
|
||||
|
||||
COPY ./entrypoint.sh /sd/
|
||||
ENTRYPOINT /sd/entrypoint.sh
|
||||
|
@ -46,8 +46,8 @@ Features:
|
||||
|
||||
* Gradio GUI: Idiot-proof, fully featured frontend for both txt2img and img2img generation
|
||||
* No more manually typing parameters, now all you have to do is write your prompt and adjust sliders
|
||||
* GFPGAN Face Correction 🔥: [Download the model](https://github.com/sd-webui/stable-diffusion-webui#gfpgan)Automatically correct distorted faces with a built-in GFPGAN option, fixes them in less than half a second
|
||||
* RealESRGAN Upscaling 🔥: [Download the models](https://github.com/sd-webui/stable-diffusion-webui#realesrgan) Boosts the resolution of images with a built-in RealESRGAN option
|
||||
* GFPGAN Face Correction 🔥: [Download the model](https://github.com/sd-webui/stable-diffusion-webui/wiki/Installation#optional-additional-models) Automatically correct distorted faces with a built-in GFPGAN option, fixes them in less than half a second
|
||||
* RealESRGAN Upscaling 🔥: [Download the models](https://github.com/sd-webui/stable-diffusion-webui/wiki/Installation#optional-additional-models) Boosts the resolution of images with a built-in RealESRGAN option
|
||||
* :computer: esrgan/gfpgan on cpu support :computer:
|
||||
* Textual inversion 🔥: [info](https://textual-inversion.github.io/) - requires enabling, see [here](https://github.com/hlky/sd-enable-textual-inversion), script works as usual without it enabled
|
||||
* Advanced img2img editor :art: :fire: :art:
|
||||
@ -106,7 +106,7 @@ that are not in original script.
|
||||
|
||||
### GFPGAN
|
||||
Lets you improve faces in pictures using the GFPGAN model. There is a checkbox in every tab to use GFPGAN at 100%, and
|
||||
also a separate tab that just allows you to use GFPGAN on any picture, with a slider that controls how strongthe effect is.
|
||||
also a separate tab that just allows you to use GFPGAN on any picture, with a slider that controls how strong the effect is.
|
||||
|
||||
![](images/GFPGAN.png)
|
||||
|
||||
|
@ -12,8 +12,9 @@ txt2img:
|
||||
# 5: Write sample info files
|
||||
# 6: write sample info to log file
|
||||
# 7: jpg samples
|
||||
# 8: Fix faces using GFPGAN
|
||||
# 9: Upscale images using RealESRGAN
|
||||
# 8: Filter NSFW content
|
||||
# 9: Fix faces using GFPGAN
|
||||
# 10: Upscale images using RealESRGAN
|
||||
toggles: [1, 2, 3, 4, 5]
|
||||
sampler_name: k_lms
|
||||
ddim_eta: 0.0 # legacy name, applies to all algorithms.
|
||||
@ -40,8 +41,10 @@ img2img:
|
||||
# 6: Sort samples by prompt
|
||||
# 7: Write sample info files
|
||||
# 8: jpg samples
|
||||
# 9: Fix faces using GFPGAN
|
||||
# 10: Upscale images using Real-ESRGAN
|
||||
# 9: Color correction
|
||||
# 10: Filter NSFW content
|
||||
# 11: Fix faces using GFPGAN
|
||||
# 12: Upscale images using Real-ESRGAN
|
||||
toggles: [1, 4, 5, 6, 7]
|
||||
sampler_name: k_lms
|
||||
ddim_eta: 0.0
|
||||
|
@ -1,14 +1,19 @@
|
||||
# UI defaults configuration file. It is automatically loaded if located at configs/webui/webui_streamlit.yaml.
|
||||
# Any changes made here will be available automatically on the web app without having to stop it.
|
||||
# You may add overrides in a file named "userconfig_streamlit.yaml" in this folder, which can contain any subset
|
||||
# of the properties below.
|
||||
general:
|
||||
gpu: 0
|
||||
outdir: outputs
|
||||
ckpt: "models/ldm/stable-diffusion-v1/model.ckpt"
|
||||
fp:
|
||||
name: 'embeddings/alex/embeddings_gs-11000.pt'
|
||||
default_model: "Stable Diffusion v1.4"
|
||||
default_model_config: "configs/stable-diffusion/v1-inference.yaml"
|
||||
default_model_path: "models/ldm/stable-diffusion-v1/model.ckpt"
|
||||
use_sd_concepts_library: True
|
||||
sd_concepts_library_folder: "models/custom/sd-concepts-library"
|
||||
GFPGAN_dir: "./src/gfpgan"
|
||||
RealESRGAN_dir: "./src/realesrgan"
|
||||
RealESRGAN_model: "RealESRGAN_x4plus"
|
||||
LDSR_dir: "./src/latent-diffusion"
|
||||
outdir_txt2img: outputs/txt2img-samples
|
||||
outdir_img2img: outputs/img2img-samples
|
||||
gfpgan_cpu: False
|
||||
@ -16,88 +21,161 @@ general:
|
||||
extra_models_cpu: False
|
||||
extra_models_gpu: False
|
||||
save_metadata: True
|
||||
save_format: "png"
|
||||
skip_grid: False
|
||||
skip_save: False
|
||||
grid_format: "jpg:95"
|
||||
n_rows: -1
|
||||
no_verify_input: False
|
||||
no_half: False
|
||||
use_float16: False
|
||||
precision: "autocast"
|
||||
optimized: False
|
||||
optimized_turbo: False
|
||||
optimized_config: "optimizedSD/v1-inference.yaml"
|
||||
enable_attention_slicing: False
|
||||
enable_minimal_memory_usage : False
|
||||
update_preview: True
|
||||
update_preview_frequency: 1
|
||||
update_preview_frequency: 5
|
||||
|
||||
txt2img:
|
||||
prompt:
|
||||
height: 512
|
||||
width: 512
|
||||
cfg_scale: 5.0
|
||||
cfg_scale: 7.5
|
||||
seed: ""
|
||||
batch_count: 1
|
||||
batch_size: 1
|
||||
sampling_steps: 50
|
||||
default_sampler: "k_lms"
|
||||
sampling_steps: 30
|
||||
default_sampler: "k_euler"
|
||||
separate_prompts: False
|
||||
update_preview: True
|
||||
update_preview_frequency: 5
|
||||
normalize_prompt_weights: True
|
||||
save_individual_images: True
|
||||
save_grid: True
|
||||
group_by_prompt: True
|
||||
save_as_jpg: False
|
||||
use_GFPGAN: True
|
||||
use_RealESRGAN: True
|
||||
use_GFPGAN: False
|
||||
use_RealESRGAN: False
|
||||
RealESRGAN_model: "RealESRGAN_x4plus"
|
||||
variant_amount: 0.0
|
||||
variant_seed: ""
|
||||
write_info_files: True
|
||||
slider_steps: {
|
||||
sampling: 1
|
||||
}
|
||||
slider_bounds: {
|
||||
sampling: {
|
||||
lower: 1,
|
||||
upper: 150
|
||||
}
|
||||
}
|
||||
|
||||
txt2vid:
|
||||
default_model: "CompVis/stable-diffusion-v1-4"
|
||||
custom_models_list: ["CompVis/stable-diffusion-v1-4", "naclbit/trinart_stable_diffusion_v2", "hakurei/waifu-diffusion", "osanseviero/BigGAN-deep-128"]
|
||||
prompt:
|
||||
height: 512
|
||||
width: 512
|
||||
cfg_scale: 7.5
|
||||
seed: ""
|
||||
batch_count: 1
|
||||
batch_size: 1
|
||||
sampling_steps: 30
|
||||
num_inference_steps: 200
|
||||
default_sampler: "k_euler"
|
||||
scheduler_name: "klms"
|
||||
separate_prompts: False
|
||||
update_preview: True
|
||||
update_preview_frequency: 5
|
||||
dynamic_preview_frequency: True
|
||||
normalize_prompt_weights: True
|
||||
save_individual_images: True
|
||||
save_video: True
|
||||
group_by_prompt: True
|
||||
write_info_files: True
|
||||
do_loop: False
|
||||
save_as_jpg: False
|
||||
use_GFPGAN: False
|
||||
use_RealESRGAN: False
|
||||
RealESRGAN_model: "RealESRGAN_x4plus"
|
||||
variant_amount: 0.0
|
||||
variant_seed: ""
|
||||
beta_start: 0.00085
|
||||
beta_end: 0.012
|
||||
beta_scheduler_type: "linear"
|
||||
max_frames: 1000
|
||||
slider_steps: {
|
||||
sampling: 1
|
||||
}
|
||||
slider_bounds: {
|
||||
sampling: {
|
||||
lower: 1,
|
||||
upper: 150
|
||||
}
|
||||
}
|
||||
|
||||
img2img:
|
||||
prompt:
|
||||
sampling_steps: 50
|
||||
# Adding an int to toggles enables the corresponding feature.
|
||||
# 0: Create prompt matrix (separate multiple prompts using |, and get all combinations of them)
|
||||
# 1: Normalize Prompt Weights (ensure sum of weights add up to 1.0)
|
||||
# 2: Loopback (use images from previous batch when creating next batch)
|
||||
# 3: Random loopback seed
|
||||
# 4: Save individual images
|
||||
# 5: Save grid
|
||||
# 6: Sort samples by prompt
|
||||
# 7: Write sample info files
|
||||
# 8: jpg samples
|
||||
# 9: Fix faces using GFPGAN
|
||||
# 10: Upscale images using Real-ESRGAN
|
||||
sampler_name: k_lms
|
||||
denoising_strength: 0.45
|
||||
# 0: Keep masked area
|
||||
# 1: Regenerate only masked area
|
||||
mask_mode: 0
|
||||
# 0: Just resize
|
||||
# 1: Crop and resize
|
||||
# 2: Resize and fill
|
||||
resize_mode: 0
|
||||
# Leave blank for random seed:
|
||||
seed: ""
|
||||
ddim_eta: 0.0
|
||||
cfg_scale: 5.0
|
||||
batch_count: 1
|
||||
batch_size: 1
|
||||
height: 512
|
||||
width: 512
|
||||
# Textual inversion embeddings file path:
|
||||
fp: ""
|
||||
loopback: True
|
||||
random_seed_loopback: True
|
||||
separate_prompts: False
|
||||
normalize_prompt_weights: True
|
||||
save_individual_images: True
|
||||
save_grid: True
|
||||
group_by_prompt: True
|
||||
save_as_jpg: False
|
||||
use_GFPGAN: True
|
||||
use_RealESRGAN: True
|
||||
RealESRGAN_model: "RealESRGAN_x4plus"
|
||||
variant_amount: 0.0
|
||||
variant_seed: ""
|
||||
prompt:
|
||||
sampling_steps: 30
|
||||
# Adding an int to toggles enables the corresponding feature.
|
||||
# 0: Create prompt matrix (separate multiple prompts using |, and get all combinations of them)
|
||||
# 1: Normalize Prompt Weights (ensure sum of weights add up to 1.0)
|
||||
# 2: Loopback (use images from previous batch when creating next batch)
|
||||
# 3: Random loopback seed
|
||||
# 4: Save individual images
|
||||
# 5: Save grid
|
||||
# 6: Sort samples by prompt
|
||||
# 7: Write sample info files
|
||||
# 8: jpg samples
|
||||
# 9: Fix faces using GFPGAN
|
||||
# 10: Upscale images using Real-ESRGAN
|
||||
sampler_name: "k_euler"
|
||||
denoising_strength: 0.75
|
||||
# 0: Keep masked area
|
||||
# 1: Regenerate only masked area
|
||||
mask_mode: 0
|
||||
mask_restore: False
|
||||
# 0: Just resize
|
||||
# 1: Crop and resize
|
||||
# 2: Resize and fill
|
||||
resize_mode: 0
|
||||
# Leave blank for random seed:
|
||||
seed: ""
|
||||
ddim_eta: 0.0
|
||||
cfg_scale: 7.5
|
||||
batch_count: 1
|
||||
batch_size: 1
|
||||
height: 512
|
||||
width: 512
|
||||
# Textual inversion embeddings file path:
|
||||
fp: ""
|
||||
loopback: True
|
||||
random_seed_loopback: True
|
||||
separate_prompts: False
|
||||
update_preview: True
|
||||
update_preview_frequency: 5
|
||||
normalize_prompt_weights: True
|
||||
save_individual_images: True
|
||||
save_grid: True
|
||||
group_by_prompt: True
|
||||
save_as_jpg: False
|
||||
use_GFPGAN: False
|
||||
use_RealESRGAN: False
|
||||
RealESRGAN_model: "RealESRGAN_x4plus"
|
||||
variant_amount: 0.0
|
||||
variant_seed: ""
|
||||
write_info_files: True
|
||||
slider_steps: {
|
||||
sampling: 1
|
||||
}
|
||||
slider_bounds: {
|
||||
sampling: {
|
||||
lower: 1,
|
||||
upper: 150
|
||||
}
|
||||
}
|
||||
|
||||
gfpgan:
|
||||
strength: 100
|
||||
|
||||
|
@ -2,7 +2,7 @@ version: '3.3'
|
||||
|
||||
services:
|
||||
stable-diffusion:
|
||||
container_name: sd
|
||||
container_name: sd-webui
|
||||
build:
|
||||
context: .
|
||||
dockerfile: Dockerfile
|
||||
@ -12,6 +12,7 @@ services:
|
||||
volumes:
|
||||
- .:/sd
|
||||
- ./outputs:/sd/outputs
|
||||
- ./model_cache:/sd/model_cache
|
||||
- conda_env:/opt/conda
|
||||
- root_profile:/root
|
||||
ports:
|
||||
@ -21,7 +22,7 @@ services:
|
||||
resources:
|
||||
reservations:
|
||||
devices:
|
||||
- capabilities: [gpu]
|
||||
- capabilities: [ gpu ]
|
||||
|
||||
volumes:
|
||||
conda_env:
|
||||
|
9
docker-reset.sh
Normal file → Executable file
9
docker-reset.sh
Normal file → Executable file
@ -10,12 +10,13 @@ echo $(pwd)
|
||||
read -p "Is the directory above correct to run reset on? (y/n) " -n 1 DIRCONFIRM
|
||||
if [[ $DIRCONFIRM =~ ^[Yy]$ ]]; then
|
||||
docker compose down
|
||||
docker image rm stable-diffusion_stable-diffusion:latest
|
||||
docker volume rm stable-diffusion_conda_env
|
||||
docker volume rm stable-diffusion_root_profile
|
||||
docker image rm stable-diffusion-webui_stable-diffusion:latest
|
||||
docker volume rm stable-diffusion-webui_conda_env
|
||||
docker volume rm stable-diffusion-webui_root_profile
|
||||
echo "Remove ./src"
|
||||
sudo rm -rf src
|
||||
sudo rm -rf latent_diffusion.egg-info
|
||||
sudo rm -rf gfpgan
|
||||
sudo rm -rf sd_webui.egg-info
|
||||
sudo rm .env_updated
|
||||
else
|
||||
echo "Exited without resetting"
|
||||
|
@ -3,26 +3,36 @@
|
||||
# Starts the gui inside the docker container using the conda env
|
||||
#
|
||||
|
||||
# set -x
|
||||
|
||||
SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
|
||||
cd $SCRIPT_DIR
|
||||
export PYTHONPATH=$SCRIPT_DIR
|
||||
|
||||
MODEL_DIR="${SCRIPT_DIR}/model_cache"
|
||||
# Array of model files to pre-download
|
||||
# local filename
|
||||
# local path in container (no trailing slash)
|
||||
# download URL
|
||||
# sha256sum
|
||||
MODEL_FILES=(
|
||||
'model.ckpt /sd/models/ldm/stable-diffusion-v1 https://www.googleapis.com/storage/v1/b/aai-blog-files/o/sd-v1-4.ckpt?alt=media fe4efff1e174c627256e44ec2991ba279b3816e364b49f9be2abc0b3ff3f8556'
|
||||
'GFPGANv1.3.pth /sd/src/gfpgan/experiments/pretrained_models https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth c953a88f2727c85c3d9ae72e2bd4846bbaf59fe6972ad94130e23e7017524a70'
|
||||
'RealESRGAN_x4plus.pth /sd/src/realesrgan/experiments/pretrained_models https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth 4fa0d38905f75ac06eb49a7951b426670021be3018265fd191d2125df9d682f1'
|
||||
'RealESRGAN_x4plus_anime_6B.pth /sd/src/realesrgan/experiments/pretrained_models https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth f872d837d3c90ed2e05227bed711af5671a6fd1c9f7d7e91c911a61f155e99da'
|
||||
'model.ckpt models/ldm/stable-diffusion-v1 https://www.googleapis.com/storage/v1/b/aai-blog-files/o/sd-v1-4.ckpt?alt=media fe4efff1e174c627256e44ec2991ba279b3816e364b49f9be2abc0b3ff3f8556'
|
||||
'GFPGANv1.3.pth src/gfpgan/experiments/pretrained_models https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth c953a88f2727c85c3d9ae72e2bd4846bbaf59fe6972ad94130e23e7017524a70'
|
||||
'RealESRGAN_x4plus.pth src/realesrgan/experiments/pretrained_models https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth 4fa0d38905f75ac06eb49a7951b426670021be3018265fd191d2125df9d682f1'
|
||||
'RealESRGAN_x4plus_anime_6B.pth src/realesrgan/experiments/pretrained_models https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth f872d837d3c90ed2e05227bed711af5671a6fd1c9f7d7e91c911a61f155e99da'
|
||||
'project.yaml src/latent-diffusion/experiments/pretrained_models https://heibox.uni-heidelberg.de/f/31a76b13ea27482981b4/?dl=1 9d6ad53c5dafeb07200fb712db14b813b527edd262bc80ea136777bdb41be2ba'
|
||||
'model.ckpt src/latent-diffusion/experiments/pretrained_models https://heibox.uni-heidelberg.de/f/578df07c8fc04ffbadf3/?dl=1 c209caecac2f97b4bb8f4d726b70ac2ac9b35904b7fc99801e1f5e61f9210c13'
|
||||
)
|
||||
|
||||
# Conda environment installs/updates
|
||||
# @see https://github.com/ContinuumIO/docker-images/issues/89#issuecomment-467287039
|
||||
ENV_NAME="ldm"
|
||||
ENV_FILE="/sd/environment.yaml"
|
||||
ENV_FILE="${SCRIPT_DIR}/environment.yaml"
|
||||
ENV_UPDATED=0
|
||||
ENV_MODIFIED=$(date -r $ENV_FILE "+%s")
|
||||
ENV_MODIFED_FILE="/sd/.env_updated"
|
||||
ENV_MODIFED_FILE="${SCRIPT_DIR}/.env_updated"
|
||||
if [[ -f $ENV_MODIFED_FILE ]]; then ENV_MODIFIED_CACHED=$(<${ENV_MODIFED_FILE}); else ENV_MODIFIED_CACHED=0; fi
|
||||
export PIP_EXISTS_ACTION=w
|
||||
|
||||
# Create/update conda env if needed
|
||||
if ! conda env list | grep ".*${ENV_NAME}.*" >/dev/null 2>&1; then
|
||||
@ -51,54 +61,67 @@ conda info | grep active
|
||||
# Function to checks for valid hash for model files and download/replaces if invalid or does not exist
|
||||
validateDownloadModel() {
|
||||
local file=$1
|
||||
local path=$2
|
||||
local path="${SCRIPT_DIR}/${2}"
|
||||
local url=$3
|
||||
local hash=$4
|
||||
|
||||
echo "checking ${file}..."
|
||||
sha256sum --check --status <<< "${hash} ${path}/${file}"
|
||||
sha256sum --check --status <<< "${hash} ${MODEL_DIR}/${file}.${hash}"
|
||||
if [[ $? == "1" ]]; then
|
||||
echo "Downloading: ${url} please wait..."
|
||||
mkdir -p ${path}
|
||||
wget --output-document=${path}/${file} --no-verbose --show-progress --progress=dot:giga ${url}
|
||||
echo "saved ${file}"
|
||||
wget --output-document=${MODEL_DIR}/${file}.${hash} --no-verbose --show-progress --progress=dot:giga ${url}
|
||||
ln -sf ${MODEL_DIR}/${file}.${hash} ${path}/${file}
|
||||
if [[ -e "${path}/${file}" ]]; then
|
||||
echo "saved ${file}"
|
||||
else
|
||||
echo "error saving ${path}/${file}!"
|
||||
exit 1
|
||||
fi
|
||||
else
|
||||
echo -e "${file} is valid!\n"
|
||||
if [[ ! -e ${path}/${file} || ! -L ${path}/${file} ]]; then
|
||||
mkdir -p ${path}
|
||||
ln -sf ${MODEL_DIR}/${file}.${hash} ${path}/${file}
|
||||
echo -e "linked valid ${file}\n"
|
||||
else
|
||||
echo -e "${file} is valid!\n"
|
||||
fi
|
||||
fi
|
||||
}
|
||||
|
||||
# Validate model files
|
||||
if [[ -z $VALIDATE_MODELS || $VALIDATE_MODELS == "true" ]]; then
|
||||
echo "Validating model files..."
|
||||
for models in "${MODEL_FILES[@]}"; do
|
||||
model=($models)
|
||||
echo "Validating model files..."
|
||||
for models in "${MODEL_FILES[@]}"; do
|
||||
model=($models)
|
||||
if [[ ! -e ${model[1]}/${model[0]} || ! -L ${model[1]}/${model[0]} || -z $VALIDATE_MODELS || $VALIDATE_MODELS == "true" ]]; then
|
||||
validateDownloadModel ${model[0]} ${model[1]} ${model[2]} ${model[3]}
|
||||
done
|
||||
fi
|
||||
fi
|
||||
done
|
||||
|
||||
# Launch web gui
|
||||
cd /sd
|
||||
|
||||
if [[ -z $WEBUI_ARGS ]]; then
|
||||
launch_message="entrypoint.sh: Launching..."
|
||||
if [[ ! -z $WEBUI_SCRIPT && $WEBUI_SCRIPT == "webui_streamlit.py" ]]; then
|
||||
launch_command="streamlit run scripts/${WEBUI_SCRIPT:-webui.py} $WEBUI_ARGS"
|
||||
else
|
||||
launch_message="entrypoint.sh: Launching with arguments ${WEBUI_ARGS}"
|
||||
launch_command="python scripts/${WEBUI_SCRIPT:-webui.py} $WEBUI_ARGS"
|
||||
fi
|
||||
|
||||
launch_message="entrypoint.sh: Run ${launch_command}..."
|
||||
if [[ -z $WEBUI_RELAUNCH || $WEBUI_RELAUNCH == "true" ]]; then
|
||||
n=0
|
||||
while true; do
|
||||
|
||||
echo $launch_message
|
||||
|
||||
if (( $n > 0 )); then
|
||||
echo "Relaunch count: ${n}"
|
||||
fi
|
||||
python -u scripts/webui.py $WEBUI_ARGS
|
||||
|
||||
$launch_command
|
||||
|
||||
echo "entrypoint.sh: Process is ending. Relaunching in 0.5s..."
|
||||
((n++))
|
||||
sleep 0.5
|
||||
done
|
||||
else
|
||||
echo $launch_message
|
||||
python -u scripts/webui.py $WEBUI_ARGS
|
||||
$launch_command
|
||||
fi
|
||||
|
@ -3,39 +3,47 @@ channels:
|
||||
- pytorch
|
||||
- defaults
|
||||
dependencies:
|
||||
- git
|
||||
- python=3.8.5
|
||||
- pip=20.3
|
||||
- cudatoolkit=11.3
|
||||
- git
|
||||
- numpy=1.22.3
|
||||
- pip=20.3
|
||||
- python=3.8.5
|
||||
- pytorch=1.11.0
|
||||
- scikit-image=0.19.2
|
||||
- torchvision=0.12.0
|
||||
- numpy=1.19.2
|
||||
- pip:
|
||||
- albumentations==0.4.3
|
||||
- opencv-python==4.1.2.30
|
||||
- opencv-python-headless==4.1.2.30
|
||||
- pudb==2019.2
|
||||
- imageio==2.9.0
|
||||
- imageio-ffmpeg==0.4.2
|
||||
- pytorch-lightning==1.4.2
|
||||
- omegaconf==2.1.1
|
||||
- test-tube>=0.7.5
|
||||
- einops==0.3.0
|
||||
- torch-fidelity==0.3.0
|
||||
- transformers==4.19.2
|
||||
- torchmetrics==0.6.0
|
||||
- kornia==0.6
|
||||
- gradio==3.1.6
|
||||
- accelerate==0.12.0
|
||||
- pynvml==11.4.1
|
||||
- basicsr>=1.3.4.0
|
||||
- facexlib>=0.2.3
|
||||
- python-slugify>=6.1.2
|
||||
- streamlit>=1.12.2
|
||||
- retry>=0.9.2
|
||||
- -e .
|
||||
- -e git+https://github.com/CompVis/taming-transformers#egg=taming-transformers
|
||||
- -e git+https://github.com/openai/CLIP#egg=clip
|
||||
- -e git+https://github.com/TencentARC/GFPGAN#egg=GFPGAN
|
||||
- -e git+https://github.com/xinntao/Real-ESRGAN#egg=realesrgan
|
||||
- -e git+https://github.com/hlky/k-diffusion-sd#egg=k_diffusion
|
||||
- -e .
|
||||
- -e git+https://github.com/devilismyfriend/latent-diffusion#egg=latent-diffusion
|
||||
- accelerate==0.12.0
|
||||
- albumentations==0.4.3
|
||||
- basicsr>=1.3.4.0
|
||||
- diffusers==0.3.0
|
||||
- einops==0.3.0
|
||||
- facexlib>=0.2.3
|
||||
- gradio==3.1.6
|
||||
- imageio-ffmpeg==0.4.2
|
||||
- imageio==2.9.0
|
||||
- kornia==0.6
|
||||
- omegaconf==2.1.1
|
||||
- opencv-python-headless==4.6.0.66
|
||||
- pandas==1.4.3
|
||||
- piexif==1.1.3
|
||||
- pudb==2019.2
|
||||
- pynvml==11.4.1
|
||||
- python-slugify>=6.1.2
|
||||
- pytorch-lightning==1.4.2
|
||||
- retry>=0.9.2
|
||||
- streamlit>=1.12.2
|
||||
- streamlit-on-Hover-tabs==1.0.1
|
||||
- streamlit-option-menu==0.3.2
|
||||
- streamlit_nested_layout
|
||||
- test-tube>=0.7.5
|
||||
- tensorboard
|
||||
- torch-fidelity==0.3.0
|
||||
- torchmetrics==0.6.0
|
||||
- transformers==4.19.2
|
||||
|
@ -1,15 +1,111 @@
|
||||
.css-18e3th9 {
|
||||
padding-top: 2rem;
|
||||
padding-bottom: 10rem;
|
||||
padding-left: 5rem;
|
||||
padding-right: 5rem;
|
||||
}
|
||||
.css-1d391kg {
|
||||
padding-top: 3.5rem;
|
||||
padding-right: 1rem;
|
||||
padding-bottom: 3.5rem;
|
||||
padding-left: 1rem;
|
||||
}
|
||||
/***********************************************************
|
||||
* Additional CSS for streamlit builtin components *
|
||||
************************************************************/
|
||||
|
||||
/* Tab name (e.g. Text-to-Image) */
|
||||
button[data-baseweb="tab"] {
|
||||
font-size: 25px;
|
||||
font-size: 25px; //improve legibility
|
||||
}
|
||||
|
||||
/* Image Container (only appear after run finished) */
|
||||
.css-du1fp8 {
|
||||
justify-content: center; //center the image, especially better looks in wide screen
|
||||
}
|
||||
|
||||
/* Streamlit header */
|
||||
.css-1avcm0n {
|
||||
background-color: transparent;
|
||||
}
|
||||
|
||||
/* Main streamlit container (below header) */
|
||||
.css-18e3th9 {
|
||||
padding-top: 2rem; //reduce the empty spaces
|
||||
}
|
||||
|
||||
/* @media only for widescreen, to ensure enough space to see all */
|
||||
@media (min-width: 1024px) {
|
||||
/* Main streamlit container (below header) */
|
||||
.css-18e3th9 {
|
||||
padding-top: 0px; //reduce the empty spaces, can go fully to the top on widescreen devices
|
||||
}
|
||||
}
|
||||
|
||||
/***********************************************************
|
||||
* Additional CSS for streamlit custom/3rd party components *
|
||||
************************************************************/
|
||||
/* For stream_on_hover */
|
||||
section[data-testid="stSidebar"] > div:nth-of-type(1) {
|
||||
background-color: #111;
|
||||
}
|
||||
|
||||
button[kind="header"] {
|
||||
background-color: transparent;
|
||||
color: rgb(180, 167, 141);
|
||||
}
|
||||
|
||||
@media (hover) {
|
||||
/* header element */
|
||||
header[data-testid="stHeader"] {
|
||||
/* display: none;*/ /*suggested behavior by streamlit hover components*/
|
||||
pointer-events: none; /* disable interaction of the transparent background */
|
||||
}
|
||||
|
||||
/* The button on the streamlit navigation menu */
|
||||
button[kind="header"] {
|
||||
/* display: none;*/ /*suggested behavior by streamlit hover components*/
|
||||
pointer-events: auto; /* enable interaction of the button even if parents intereaction disabled */
|
||||
}
|
||||
|
||||
/* added to avoid main sectors (all element to the right of sidebar from) moving */
|
||||
section[data-testid="stSidebar"] {
|
||||
width: 3.5% !important;
|
||||
min-width: 3.5% !important;
|
||||
}
|
||||
|
||||
/* The navigation menu specs and size */
|
||||
section[data-testid="stSidebar"] > div {
|
||||
height: 100%;
|
||||
width: 2% !important;
|
||||
min-width: 100% !important;
|
||||
position: relative;
|
||||
z-index: 1;
|
||||
top: 0;
|
||||
left: 0;
|
||||
background-color: #111;
|
||||
overflow-x: hidden;
|
||||
transition: 0.5s ease-in-out;
|
||||
padding-top: 0px;
|
||||
white-space: nowrap;
|
||||
}
|
||||
|
||||
/* The navigation menu open and close on hover and size */
|
||||
section[data-testid="stSidebar"] > div:hover {
|
||||
width: 300px !important;
|
||||
}
|
||||
}
|
||||
|
||||
@media (max-width: 272px) {
|
||||
section[data-testid="stSidebar"] > div {
|
||||
width: 15rem;
|
||||
}
|
||||
}
|
||||
|
||||
/***********************************************************
|
||||
* Additional CSS for other elements
|
||||
************************************************************/
|
||||
button[data-baseweb="tab"] {
|
||||
font-size: 20px;
|
||||
}
|
||||
|
||||
@media (min-width: 1200px){
|
||||
h1 {
|
||||
font-size: 1.75rem;
|
||||
}
|
||||
}
|
||||
#tabs-1-tabpanel-0 > div:nth-child(1) > div > div.stTabs.css-0.exp6ofz0 {
|
||||
width: 50rem;
|
||||
align-self: center;
|
||||
}
|
||||
div.gallery:hover {
|
||||
border: 1px solid #777;
|
||||
}
|
@ -3,6 +3,8 @@ from frontend.css_and_js import css, js, call_JS, js_parse_prompt, js_copy_txt2i
|
||||
from frontend.job_manager import JobManager
|
||||
import frontend.ui_functions as uifn
|
||||
import uuid
|
||||
import torch
|
||||
|
||||
|
||||
|
||||
def draw_gradio_ui(opt, img2img=lambda x: x, txt2img=lambda x: x, imgproc=lambda x: x, txt2img_defaults={},
|
||||
@ -36,8 +38,11 @@ def draw_gradio_ui(opt, img2img=lambda x: x, txt2img=lambda x: x, imgproc=lambda
|
||||
value=txt2img_defaults['cfg_scale'], elem_id='cfg_slider')
|
||||
txt2img_seed = gr.Textbox(label="Seed (blank to randomize)", lines=1, max_lines=1,
|
||||
value=txt2img_defaults["seed"])
|
||||
txt2img_batch_size = gr.Slider(minimum=1, maximum=50, step=1,
|
||||
label='Images per batch',
|
||||
value=txt2img_defaults['batch_size'])
|
||||
txt2img_batch_count = gr.Slider(minimum=1, maximum=50, step=1,
|
||||
label='Number of images to generate',
|
||||
label='Number of batches to generate',
|
||||
value=txt2img_defaults['n_iter'])
|
||||
|
||||
txt2img_job_ui = job_manager.draw_gradio_ui() if job_manager else None
|
||||
@ -51,11 +56,15 @@ def draw_gradio_ui(opt, img2img=lambda x: x, txt2img=lambda x: x, imgproc=lambda
|
||||
gr.Markdown(
|
||||
"Select an image from the gallery, then click one of the buttons below to perform an action.")
|
||||
with gr.Row(elem_id='txt2img_actions_row'):
|
||||
gr.Button("Copy to clipboard").click(fn=None,
|
||||
inputs=output_txt2img_gallery,
|
||||
outputs=[],
|
||||
# _js=js_copy_to_clipboard( 'txt2img_gallery_output')
|
||||
)
|
||||
gr.Button("Copy to clipboard").click(
|
||||
fn=None,
|
||||
inputs=output_txt2img_gallery,
|
||||
outputs=[],
|
||||
_js=call_JS(
|
||||
"copyImageFromGalleryToClipboard",
|
||||
fromId="txt2img_gallery_output"
|
||||
)
|
||||
)
|
||||
output_txt2img_copy_to_input_btn = gr.Button("Push to img2img")
|
||||
output_txt2img_to_imglab = gr.Button("Send to Lab", visible=True)
|
||||
|
||||
@ -91,9 +100,6 @@ def draw_gradio_ui(opt, img2img=lambda x: x, txt2img=lambda x: x, imgproc=lambda
|
||||
with gr.TabItem('Advanced'):
|
||||
txt2img_toggles = gr.CheckboxGroup(label='', choices=txt2img_toggles,
|
||||
value=txt2img_toggle_defaults, type="index")
|
||||
txt2img_batch_size = gr.Slider(minimum=1, maximum=8, step=1,
|
||||
label='Batch size (how many images are in a batch; memory-hungry)',
|
||||
value=txt2img_defaults['batch_size'])
|
||||
txt2img_realesrgan_model_name = gr.Dropdown(label='RealESRGAN model',
|
||||
choices=['RealESRGAN_x4plus',
|
||||
'RealESRGAN_x4plus_anime_6B'],
|
||||
@ -124,20 +130,27 @@ def draw_gradio_ui(opt, img2img=lambda x: x, txt2img=lambda x: x, imgproc=lambda
|
||||
inputs=txt2img_inputs,
|
||||
outputs=txt2img_outputs
|
||||
)
|
||||
use_queue = False
|
||||
else:
|
||||
use_queue = True
|
||||
|
||||
txt2img_btn.click(
|
||||
txt2img_func,
|
||||
txt2img_inputs,
|
||||
txt2img_outputs
|
||||
txt2img_outputs,
|
||||
api_name='txt2img',
|
||||
queue=use_queue
|
||||
)
|
||||
txt2img_prompt.submit(
|
||||
txt2img_func,
|
||||
txt2img_inputs,
|
||||
txt2img_outputs
|
||||
txt2img_outputs,
|
||||
queue=use_queue
|
||||
)
|
||||
|
||||
# txt2img_width.change(fn=uifn.update_dimensions_info, inputs=[txt2img_width, txt2img_height], outputs=txt2img_dimensions_info_text_box)
|
||||
# txt2img_height.change(fn=uifn.update_dimensions_info, inputs=[txt2img_width, txt2img_height], outputs=txt2img_dimensions_info_text_box)
|
||||
txt2img_width.change(fn=uifn.update_dimensions_info, inputs=[txt2img_width, txt2img_height], outputs=txt2img_dimensions_info_text_box)
|
||||
txt2img_height.change(fn=uifn.update_dimensions_info, inputs=[txt2img_width, txt2img_height], outputs=txt2img_dimensions_info_text_box)
|
||||
txt2img_dimensions_info_text_box.value = uifn.update_dimensions_info(txt2img_width.value, txt2img_height.value)
|
||||
|
||||
# Temporarily disable prompt parsing until memory issues could be solved
|
||||
# See #676
|
||||
@ -189,8 +202,9 @@ def draw_gradio_ui(opt, img2img=lambda x: x, txt2img=lambda x: x, imgproc=lambda
|
||||
with gr.TabItem("Editor Options"):
|
||||
with gr.Row():
|
||||
# disable Uncrop for now
|
||||
# choices=["Mask", "Crop", "Uncrop"]
|
||||
img2img_image_editor_mode = gr.Radio(choices=["Mask", "Crop"],
|
||||
choices=["Mask", "Crop", "Uncrop"]
|
||||
#choices=["Mask", "Crop"]
|
||||
img2img_image_editor_mode = gr.Radio(choices=choices,
|
||||
label="Image Editor Mode",
|
||||
value="Mask", elem_id='edit_mode_select',
|
||||
visible=True)
|
||||
@ -199,9 +213,13 @@ def draw_gradio_ui(opt, img2img=lambda x: x, txt2img=lambda x: x, imgproc=lambda
|
||||
value=img2img_mask_modes[img2img_defaults['mask_mode']],
|
||||
visible=True)
|
||||
|
||||
img2img_mask_blur_strength = gr.Slider(minimum=1, maximum=10, step=1,
|
||||
img2img_mask_restore = gr.Checkbox(label="Only modify regenerated parts of image",
|
||||
value=img2img_defaults['mask_restore'],
|
||||
visible=True)
|
||||
|
||||
img2img_mask_blur_strength = gr.Slider(minimum=1, maximum=100, step=1,
|
||||
label="How much blurry should the mask be? (to avoid hard edges)",
|
||||
value=3, visible=False)
|
||||
value=3, visible=True)
|
||||
|
||||
img2img_resize = gr.Radio(label="Resize mode",
|
||||
choices=["Just resize", "Crop and resize",
|
||||
@ -293,7 +311,7 @@ def draw_gradio_ui(opt, img2img=lambda x: x, txt2img=lambda x: x, imgproc=lambda
|
||||
img2img_height
|
||||
],
|
||||
[img2img_image_editor, img2img_image_mask, img2img_btn_editor, img2img_btn_mask,
|
||||
img2img_painterro_btn, img2img_mask, img2img_mask_blur_strength]
|
||||
img2img_painterro_btn, img2img_mask, img2img_mask_blur_strength, img2img_mask_restore]
|
||||
)
|
||||
|
||||
# img2img_image_editor_mode.change(
|
||||
@ -334,8 +352,8 @@ def draw_gradio_ui(opt, img2img=lambda x: x, txt2img=lambda x: x, imgproc=lambda
|
||||
)
|
||||
|
||||
img2img_func = img2img
|
||||
img2img_inputs = [img2img_prompt, img2img_image_editor_mode, img2img_mask,
|
||||
img2img_mask_blur_strength, img2img_steps, img2img_sampling, img2img_toggles,
|
||||
img2img_inputs = [img2img_prompt, img2img_image_editor_mode, img2img_mask, img2img_mask_blur_strength,
|
||||
img2img_mask_restore, img2img_steps, img2img_sampling, img2img_toggles,
|
||||
img2img_realesrgan_model_name, img2img_batch_count, img2img_cfg,
|
||||
img2img_denoising, img2img_seed, img2img_height, img2img_width, img2img_resize,
|
||||
img2img_image_editor, img2img_image_mask, img2img_embeddings]
|
||||
@ -349,11 +367,16 @@ def draw_gradio_ui(opt, img2img=lambda x: x, txt2img=lambda x: x, imgproc=lambda
|
||||
inputs=img2img_inputs,
|
||||
outputs=img2img_outputs,
|
||||
)
|
||||
use_queue = False
|
||||
else:
|
||||
use_queue = True
|
||||
|
||||
img2img_btn_mask.click(
|
||||
img2img_func,
|
||||
img2img_inputs,
|
||||
img2img_outputs
|
||||
img2img_outputs,
|
||||
api_name="img2img",
|
||||
queue=use_queue
|
||||
)
|
||||
|
||||
def img2img_submit_params():
|
||||
@ -383,6 +406,7 @@ def draw_gradio_ui(opt, img2img=lambda x: x, txt2img=lambda x: x, imgproc=lambda
|
||||
outputs=img2img_dimensions_info_text_box)
|
||||
img2img_height.change(fn=uifn.update_dimensions_info, inputs=[img2img_width, img2img_height],
|
||||
outputs=img2img_dimensions_info_text_box)
|
||||
img2img_dimensions_info_text_box.value = uifn.update_dimensions_info(img2img_width.value, img2img_height.value)
|
||||
|
||||
with gr.TabItem("Image Lab", id='imgproc_tab'):
|
||||
gr.Markdown("Post-process results")
|
||||
@ -397,8 +421,7 @@ def draw_gradio_ui(opt, img2img=lambda x: x, txt2img=lambda x: x, imgproc=lambda
|
||||
# value=gfpgan_defaults['strength'])
|
||||
# select folder with images to process
|
||||
with gr.TabItem('Batch Process'):
|
||||
imgproc_folder = gr.File(label="Batch Process", file_count="multiple", source="upload",
|
||||
interactive=True, type="file")
|
||||
imgproc_folder = gr.File(label="Batch Process", file_count="multiple", interactive=True, type="file")
|
||||
imgproc_pngnfo = gr.Textbox(label="PNG Metadata", placeholder="PngNfo", visible=False,
|
||||
max_lines=5)
|
||||
with gr.Row():
|
||||
@ -540,7 +563,7 @@ def draw_gradio_ui(opt, img2img=lambda x: x, txt2img=lambda x: x, imgproc=lambda
|
||||
imgproc_width, imgproc_cfg, imgproc_denoising, imgproc_seed,
|
||||
imgproc_gfpgan_strength, imgproc_ldsr_steps, imgproc_ldsr_pre_downSample,
|
||||
imgproc_ldsr_post_downSample],
|
||||
[imgproc_output])
|
||||
[imgproc_output], api_name="imgproc")
|
||||
|
||||
imgproc_source.change(
|
||||
uifn.get_png_nfo,
|
||||
@ -631,11 +654,12 @@ def draw_gradio_ui(opt, img2img=lambda x: x, txt2img=lambda x: x, imgproc=lambda
|
||||
"""
|
||||
gr.HTML("""
|
||||
<div id="90" style="max-width: 100%; font-size: 14px; text-align: center;" class="output-markdown gr-prose border-solid border border-gray-200 rounded gr-panel">
|
||||
<p>For help and advanced usage guides, visit the <a href="https://github.com/sd-webui/stable-diffusion-webui/wiki" target="_blank">Project Wiki</a></p>
|
||||
<p><a href="https://github.com/sd-webui/stable-diffusion-webui">Stable Diffusion WebUI</a> is an open-source project.
|
||||
If you would like to contribute to development or test bleeding edge builds, use the <a href="https://github.com/sd-webui/stable-diffusion-webui/tree/dev" target="_blank">dev branch</a>.</p>
|
||||
<p>For help and advanced usage guides, visit the <a href="https://github.com/hlky/stable-diffusion-webui/wiki" target="_blank">Project Wiki</a></p>
|
||||
<p>Stable Diffusion WebUI is an open-source project. You can find the latest stable builds on the <a href="https://github.com/hlky/stable-diffusion" target="_blank">main repository</a>.
|
||||
If you would like to contribute to development or test bleeding edge builds, you can visit the <a href="https://github.com/hlky/stable-diffusion-webui" target="_blank">developement repository</a>.</p>
|
||||
<p>Device ID {current_device_index}: {current_device_name}<br/>{total_device_count} total devices</p>
|
||||
</div>
|
||||
""")
|
||||
""".format(current_device_name=torch.cuda.get_device_name(), current_device_index=torch.cuda.current_device(), total_device_count=torch.cuda.device_count()))
|
||||
# Hack: Detect the load event on the frontend
|
||||
# Won't be needed in the next version of gradio
|
||||
# See the relevant PR: https://github.com/gradio-app/gradio/pull/2108
|
||||
|
57
frontend/image_metadata.py
Normal file
57
frontend/image_metadata.py
Normal file
@ -0,0 +1,57 @@
|
||||
''' Class to store image generation parameters to be stored as metadata in the image'''
|
||||
from __future__ import annotations
|
||||
from dataclasses import dataclass, asdict
|
||||
from typing import Dict, Optional
|
||||
from PIL import Image
|
||||
from PIL.PngImagePlugin import PngInfo
|
||||
import copy
|
||||
|
||||
@dataclass
|
||||
class ImageMetadata:
|
||||
prompt: str = None
|
||||
seed: str = None
|
||||
width: str = None
|
||||
height: str = None
|
||||
steps: str = None
|
||||
cfg_scale: str = None
|
||||
normalize_prompt_weights: str = None
|
||||
denoising_strength: str = None
|
||||
GFPGAN: str = None
|
||||
|
||||
def as_png_info(self) -> PngInfo:
|
||||
info = PngInfo()
|
||||
for key, value in self.as_dict().items():
|
||||
info.add_text(key, value)
|
||||
return info
|
||||
|
||||
def as_dict(self) -> Dict[str, str]:
|
||||
return {f"SD:{key}": str(value) for key, value in asdict(self).items() if value is not None}
|
||||
|
||||
@classmethod
|
||||
def set_on_image(cls, image: Image, metadata: ImageMetadata) -> None:
|
||||
''' Sets metadata on image, in both text form and as an ImageMetadata object '''
|
||||
if metadata:
|
||||
image.info = metadata.as_dict()
|
||||
else:
|
||||
metadata = ImageMetadata()
|
||||
image.info["ImageMetadata"] = copy.copy(metadata)
|
||||
|
||||
@classmethod
|
||||
def get_from_image(cls, image: Image) -> Optional[ImageMetadata]:
|
||||
''' Gets metadata from an image, first looking for an ImageMetadata,
|
||||
then if not found tries to construct one from the info '''
|
||||
metadata = image.info.get("ImageMetadata", None)
|
||||
if not metadata:
|
||||
found_metadata = False
|
||||
metadata = ImageMetadata()
|
||||
for key, value in image.info.items():
|
||||
if key.lower().startswith("sd:"):
|
||||
key = key[3:]
|
||||
if f"{key}" in metadata.__dict__:
|
||||
metadata.__dict__[key] = value
|
||||
found_metadata = True
|
||||
if not found_metadata:
|
||||
metadata = None
|
||||
if not metadata:
|
||||
print("Couldn't find metadata on image")
|
||||
return metadata
|
@ -1,7 +1,7 @@
|
||||
''' Provides simple job management for gradio, allowing viewing and stopping in-progress multi-batch generations '''
|
||||
from __future__ import annotations
|
||||
import gradio as gr
|
||||
from gradio.components import Component, Gallery
|
||||
from gradio.components import Component, Gallery, Slider
|
||||
from threading import Event, Timer
|
||||
from typing import Callable, List, Dict, Tuple, Optional, Any
|
||||
from dataclasses import dataclass, field
|
||||
@ -9,6 +9,7 @@ from functools import partial
|
||||
from PIL.Image import Image
|
||||
import uuid
|
||||
import traceback
|
||||
import time
|
||||
|
||||
|
||||
@dataclass(eq=True, frozen=True)
|
||||
@ -30,9 +31,21 @@ class JobInfo:
|
||||
session_key: str
|
||||
job_token: Optional[int] = None
|
||||
images: List[Image] = field(default_factory=list)
|
||||
active_image: Image = None
|
||||
rec_steps_enabled: bool = False
|
||||
rec_steps_imgs: List[Image] = field(default_factory=list)
|
||||
rec_steps_intrvl: int = None
|
||||
rec_steps_to_gallery: bool = False
|
||||
rec_steps_to_file: bool = False
|
||||
should_stop: Event = field(default_factory=Event)
|
||||
refresh_active_image_requested: Event = field(default_factory=Event)
|
||||
refresh_active_image_done: Event = field(default_factory=Event)
|
||||
stop_cur_iter: Event = field(default_factory=Event)
|
||||
active_iteration_cnt: int = field(default_factory=int)
|
||||
job_status: str = field(default_factory=str)
|
||||
finished: bool = False
|
||||
started: bool = False
|
||||
timestamp: float = None
|
||||
removed_output_idxs: List[int] = field(default_factory=list)
|
||||
|
||||
|
||||
@ -76,7 +89,7 @@ class JobManagerUi:
|
||||
'''
|
||||
return self._job_manager._wrap_func(
|
||||
func=func, inputs=inputs, outputs=outputs,
|
||||
refresh_btn=self._refresh_btn, stop_btn=self._stop_btn, status_text=self._status_text
|
||||
job_ui=self
|
||||
)
|
||||
|
||||
_refresh_btn: gr.Button
|
||||
@ -84,10 +97,19 @@ class JobManagerUi:
|
||||
_status_text: gr.Textbox
|
||||
_stop_all_session_btn: gr.Button
|
||||
_free_done_sessions_btn: gr.Button
|
||||
_active_image: gr.Image
|
||||
_active_image_stop_btn: gr.Button
|
||||
_active_image_refresh_btn: gr.Button
|
||||
_rec_steps_intrvl_sldr: gr.Slider
|
||||
_rec_steps_checkbox: gr.Checkbox
|
||||
_save_rec_steps_to_gallery_chkbx: gr.Checkbox
|
||||
_save_rec_steps_to_file_chkbx: gr.Checkbox
|
||||
_job_manager: JobManager
|
||||
|
||||
|
||||
class JobManager:
|
||||
JOB_MAX_START_TIME = 5.0 # How long can a job be stuck 'starting' before assuming it isn't running
|
||||
|
||||
def __init__(self, max_jobs: int):
|
||||
self._max_jobs: int = max_jobs
|
||||
self._avail_job_tokens: List[Any] = list(range(max_jobs))
|
||||
@ -102,11 +124,23 @@ class JobManager:
|
||||
'''
|
||||
assert gr.context.Context.block is not None, "draw_gradio_ui must be called within a 'gr.Blocks' 'with' context"
|
||||
with gr.Tabs():
|
||||
with gr.TabItem("Current Session"):
|
||||
with gr.TabItem("Job Controls"):
|
||||
with gr.Row():
|
||||
stop_btn = gr.Button("Stop", elem_id="stop", variant="secondary")
|
||||
refresh_btn = gr.Button("Refresh", elem_id="refresh", variant="secondary")
|
||||
stop_btn = gr.Button("Stop All Batches", elem_id="stop", variant="secondary")
|
||||
refresh_btn = gr.Button("Refresh Finished Batches", elem_id="refresh", variant="secondary")
|
||||
status_text = gr.Textbox(placeholder="Job Status", interactive=False, show_label=False)
|
||||
with gr.Row():
|
||||
active_image_stop_btn = gr.Button("Skip Active Batch", variant="secondary")
|
||||
active_image_refresh_btn = gr.Button("View Batch Progress", variant="secondary")
|
||||
active_image = gr.Image(type="pil", interactive=False, visible=False, elem_id="active_iteration_image")
|
||||
with gr.TabItem("Batch Progress Settings"):
|
||||
with gr.Row():
|
||||
record_steps_checkbox = gr.Checkbox(value=False, label="Enable Batch Progress Grid")
|
||||
record_steps_interval_slider = gr.Slider(
|
||||
value=3, label="Record Interval (steps)", minimum=1, maximum=25, step=1)
|
||||
with gr.Row() as record_steps_box:
|
||||
steps_to_gallery_checkbox = gr.Checkbox(value=False, label="Save Progress Grid to Gallery")
|
||||
steps_to_file_checkbox = gr.Checkbox(value=False, label="Save Progress Grid to File")
|
||||
with gr.TabItem("Maintenance"):
|
||||
with gr.Row():
|
||||
gr.Markdown(
|
||||
@ -118,9 +152,15 @@ class JobManager:
|
||||
free_done_sessions_btn = gr.Button(
|
||||
"Clear Finished Jobs", elem_id="clear_finished", variant="secondary"
|
||||
)
|
||||
|
||||
return JobManagerUi(_refresh_btn=refresh_btn, _stop_btn=stop_btn, _status_text=status_text,
|
||||
_stop_all_session_btn=stop_all_sessions_btn, _free_done_sessions_btn=free_done_sessions_btn,
|
||||
_job_manager=self)
|
||||
_active_image=active_image, _active_image_stop_btn=active_image_stop_btn,
|
||||
_active_image_refresh_btn=active_image_refresh_btn,
|
||||
_rec_steps_checkbox=record_steps_checkbox,
|
||||
_save_rec_steps_to_gallery_chkbx=steps_to_gallery_checkbox,
|
||||
_save_rec_steps_to_file_chkbx=steps_to_file_checkbox,
|
||||
_rec_steps_intrvl_sldr=record_steps_interval_slider, _job_manager=self)
|
||||
|
||||
def clear_all_finished_jobs(self):
|
||||
''' Removes all currently finished jobs, across all sessions.
|
||||
@ -134,6 +174,7 @@ class JobManager:
|
||||
for session in self._sessions.values():
|
||||
for job in session.jobs.values():
|
||||
job.should_stop.set()
|
||||
job.stop_cur_iter.set()
|
||||
|
||||
def _get_job_token(self, block: bool = False) -> Optional[int]:
|
||||
''' Attempts to acquire a job token, optionally blocking until available '''
|
||||
@ -175,6 +216,26 @@ class JobManager:
|
||||
job_info.should_stop.set()
|
||||
return "Stopping after current batch finishes"
|
||||
|
||||
def _refresh_cur_iter_func(self, func_key: FuncKey, session_key: str) -> List[Component]:
|
||||
''' Updates information from the active iteration '''
|
||||
session_info, job_info = self._get_call_info(func_key, session_key)
|
||||
if job_info is None:
|
||||
return [None, f"Session {session_key} was not running function {func_key}"]
|
||||
|
||||
job_info.refresh_active_image_requested.set()
|
||||
if job_info.refresh_active_image_done.wait(timeout=20.0):
|
||||
job_info.refresh_active_image_done.clear()
|
||||
return [gr.Image.update(value=job_info.active_image, visible=True), f"Sample iteration {job_info.active_iteration_cnt}"]
|
||||
return [gr.Image.update(visible=False), "Timed out getting image"]
|
||||
|
||||
def _stop_cur_iter_func(self, func_key: FuncKey, session_key: str) -> List[Component]:
|
||||
''' Marks that the active iteration should be stopped'''
|
||||
session_info, job_info = self._get_call_info(func_key, session_key)
|
||||
if job_info is None:
|
||||
return [None, f"Session {session_key} was not running function {func_key}"]
|
||||
job_info.stop_cur_iter.set()
|
||||
return [gr.Image.update(visible=False), "Stopping current iteration"]
|
||||
|
||||
def _get_call_info(self, func_key: FuncKey, session_key: str) -> Tuple[SessionInfo, JobInfo]:
|
||||
''' Helper to get the SessionInfo and JobInfo. '''
|
||||
session_info = self._sessions.get(session_key, None)
|
||||
@ -207,19 +268,22 @@ class JobManager:
|
||||
|
||||
def _pre_call_func(
|
||||
self, func_key: FuncKey, output_dummy_obj: Component, refresh_btn: gr.Button, stop_btn: gr.Button,
|
||||
status_text: gr.Textbox, session_key: str) -> List[Component]:
|
||||
status_text: gr.Textbox, active_image: gr.Image, active_refresh_btn: gr.Button, active_stop_btn: gr.Button,
|
||||
session_key: str) -> List[Component]:
|
||||
''' Called when a job is about to start '''
|
||||
session_info, job_info = self._get_call_info(func_key, session_key)
|
||||
|
||||
# If we didn't already get a token then queue up for one
|
||||
if job_info.job_token is None:
|
||||
job_info.token = self._get_job_token(block=True)
|
||||
job_info.job_token = self._get_job_token(block=True)
|
||||
|
||||
# Buttons don't seem to update unless value is set on them as well...
|
||||
return {output_dummy_obj: triggerChangeEvent(),
|
||||
refresh_btn: gr.Button.update(variant="primary", value=refresh_btn.value),
|
||||
stop_btn: gr.Button.update(variant="primary", value=stop_btn.value),
|
||||
status_text: gr.Textbox.update(value="Generation has started. Click 'Refresh' for updates")
|
||||
status_text: gr.Textbox.update(value="Generation has started. Click 'Refresh' to see finished images, 'View Batch Progress' for active images"),
|
||||
active_refresh_btn: gr.Button.update(variant="primary", value=active_refresh_btn.value),
|
||||
active_stop_btn: gr.Button.update(variant="primary", value=active_stop_btn.value),
|
||||
}
|
||||
|
||||
def _call_func(self, func_key: FuncKey, session_key: str) -> List[Component]:
|
||||
@ -228,12 +292,19 @@ class JobManager:
|
||||
if session_info is None or job_info is None:
|
||||
return []
|
||||
|
||||
job_info.started = True
|
||||
try:
|
||||
if job_info.should_stop.is_set():
|
||||
raise Exception(f"Job {job_info} requested a stop before execution began")
|
||||
outputs = job_info.func(*job_info.inputs, job_info=job_info)
|
||||
except Exception as e:
|
||||
job_info.job_status = f"Error: {e}"
|
||||
print(f"Exception processing job {job_info}: {e}\n{traceback.format_exc()}")
|
||||
outputs = []
|
||||
raise
|
||||
finally:
|
||||
job_info.finished = True
|
||||
session_info.finished_jobs[func_key] = session_info.jobs.pop(func_key)
|
||||
self._release_job_token(job_info.job_token)
|
||||
|
||||
# Filter the function output for any removed outputs
|
||||
filtered_output = []
|
||||
@ -241,11 +312,6 @@ class JobManager:
|
||||
if idx not in job_info.removed_output_idxs:
|
||||
filtered_output.append(output)
|
||||
|
||||
job_info.finished = True
|
||||
session_info.finished_jobs[func_key] = session_info.jobs.pop(func_key)
|
||||
|
||||
self._release_job_token(job_info.job_token)
|
||||
|
||||
# The wrapper added a dummy JSON output. Append a random text string
|
||||
# to fire the dummy objects 'change' event to notify that the job is done
|
||||
filtered_output.append(triggerChangeEvent())
|
||||
@ -254,12 +320,16 @@ class JobManager:
|
||||
|
||||
def _post_call_func(
|
||||
self, func_key: FuncKey, output_dummy_obj: Component, refresh_btn: gr.Button, stop_btn: gr.Button,
|
||||
status_text: gr.Textbox, session_key: str) -> List[Component]:
|
||||
status_text: gr.Textbox, active_image: gr.Image, active_refresh_btn: gr.Button, active_stop_btn: gr.Button,
|
||||
session_key: str) -> List[Component]:
|
||||
''' Called when a job completes '''
|
||||
return {output_dummy_obj: triggerChangeEvent(),
|
||||
refresh_btn: gr.Button.update(variant="secondary", value=refresh_btn.value),
|
||||
stop_btn: gr.Button.update(variant="secondary", value=stop_btn.value),
|
||||
status_text: gr.Textbox.update(value="Generation has finished!")
|
||||
status_text: gr.Textbox.update(value="Generation has finished!"),
|
||||
active_refresh_btn: gr.Button.update(variant="secondary", value=active_refresh_btn.value),
|
||||
active_stop_btn: gr.Button.update(variant="secondary", value=active_stop_btn.value),
|
||||
active_image: gr.Image.update(visible=False)
|
||||
}
|
||||
|
||||
def _update_gallery_event(self, func_key: FuncKey, session_key: str) -> List[Component]:
|
||||
@ -270,21 +340,17 @@ class JobManager:
|
||||
if session_info is None or job_info is None:
|
||||
return []
|
||||
|
||||
if job_info.finished:
|
||||
session_info.finished_jobs.pop(func_key)
|
||||
|
||||
return job_info.images
|
||||
|
||||
def _wrap_func(
|
||||
self, func: Callable, inputs: List[Component], outputs: List[Component],
|
||||
refresh_btn: gr.Button = None, stop_btn: gr.Button = None,
|
||||
status_text: Optional[gr.Textbox] = None) -> Tuple[Callable, List[Component]]:
|
||||
def _wrap_func(self, func: Callable, inputs: List[Component],
|
||||
outputs: List[Component],
|
||||
job_ui: JobManagerUi) -> Tuple[Callable, List[Component]]:
|
||||
''' handles JobManageUI's wrap_func'''
|
||||
|
||||
assert gr.context.Context.block is not None, "wrap_func must be called within a 'gr.Blocks' 'with' context"
|
||||
|
||||
# Create a unique key for this job
|
||||
func_key = FuncKey(job_id=uuid.uuid4(), func=func)
|
||||
func_key = FuncKey(job_id=uuid.uuid4().hex, func=func)
|
||||
|
||||
# Create a unique session key (next gradio release can use gr.State, see https://gradio.app/state_in_blocks/)
|
||||
if self._session_key is None:
|
||||
@ -302,31 +368,59 @@ class JobManager:
|
||||
del outputs[idx]
|
||||
break
|
||||
|
||||
# Add the session key to the inputs
|
||||
inputs += [self._session_key]
|
||||
|
||||
# Create dummy objects
|
||||
update_gallery_obj = gr.JSON(visible=False, elem_id="JobManagerDummyObject")
|
||||
update_gallery_obj.change(
|
||||
partial(self._update_gallery_event, func_key),
|
||||
[self._session_key],
|
||||
[gallery_comp]
|
||||
[gallery_comp],
|
||||
queue=False
|
||||
)
|
||||
|
||||
if refresh_btn:
|
||||
refresh_btn.variant = 'secondary'
|
||||
refresh_btn.click(
|
||||
if job_ui._refresh_btn:
|
||||
job_ui._refresh_btn.variant = 'secondary'
|
||||
job_ui._refresh_btn.click(
|
||||
partial(self._refresh_func, func_key),
|
||||
[self._session_key],
|
||||
[update_gallery_obj, status_text]
|
||||
[update_gallery_obj, job_ui._status_text],
|
||||
queue=False
|
||||
)
|
||||
|
||||
if stop_btn:
|
||||
stop_btn.variant = 'secondary'
|
||||
stop_btn.click(
|
||||
if job_ui._stop_btn:
|
||||
job_ui._stop_btn.variant = 'secondary'
|
||||
job_ui._stop_btn.click(
|
||||
partial(self._stop_wrapped_func, func_key),
|
||||
[self._session_key],
|
||||
[status_text]
|
||||
[job_ui._status_text],
|
||||
queue=False
|
||||
)
|
||||
|
||||
if job_ui._active_image and job_ui._active_image_refresh_btn:
|
||||
job_ui._active_image_refresh_btn.click(
|
||||
partial(self._refresh_cur_iter_func, func_key),
|
||||
[self._session_key],
|
||||
[job_ui._active_image, job_ui._status_text],
|
||||
queue=False
|
||||
)
|
||||
|
||||
if job_ui._active_image_stop_btn:
|
||||
job_ui._active_image_stop_btn.click(
|
||||
partial(self._stop_cur_iter_func, func_key),
|
||||
[self._session_key],
|
||||
[job_ui._active_image, job_ui._status_text],
|
||||
queue=False
|
||||
)
|
||||
|
||||
if job_ui._stop_all_session_btn:
|
||||
job_ui._stop_all_session_btn.click(
|
||||
self.stop_all_jobs, [], [],
|
||||
queue=False
|
||||
)
|
||||
|
||||
if job_ui._free_done_sessions_btn:
|
||||
job_ui._free_done_sessions_btn.click(
|
||||
self.clear_all_finished_jobs, [], [],
|
||||
queue=False
|
||||
)
|
||||
|
||||
# (ab)use gr.JSON to forward events.
|
||||
@ -343,7 +437,8 @@ class JobManager:
|
||||
# Since some parameters are optional it makes sense to use the 'dict' return value type, which requires
|
||||
# the Component as a key... so group together the UI components that the event listeners are going to update
|
||||
# to make it easy to append to function calls and outputs
|
||||
job_ui_params = [refresh_btn, stop_btn, status_text]
|
||||
job_ui_params = [job_ui._refresh_btn, job_ui._stop_btn, job_ui._status_text,
|
||||
job_ui._active_image, job_ui._active_image_refresh_btn, job_ui._active_image_stop_btn]
|
||||
job_ui_outputs = [comp for comp in job_ui_params if comp is not None]
|
||||
|
||||
# Here a chain is constructed that will make a 'pre' call, a 'run' call, and a 'post' call,
|
||||
@ -352,44 +447,70 @@ class JobManager:
|
||||
post_call_dummyobj.change(
|
||||
partial(self._post_call_func, func_key, update_gallery_obj, *job_ui_params),
|
||||
[self._session_key],
|
||||
[update_gallery_obj] + job_ui_outputs
|
||||
[update_gallery_obj] + job_ui_outputs,
|
||||
queue=False
|
||||
)
|
||||
|
||||
call_dummyobj = gr.JSON(visible=False, elem_id="JobManagerDummyObject_runCall")
|
||||
call_dummyobj.change(
|
||||
partial(self._call_func, func_key),
|
||||
[self._session_key],
|
||||
outputs + [post_call_dummyobj]
|
||||
outputs + [post_call_dummyobj],
|
||||
queue=False
|
||||
)
|
||||
|
||||
pre_call_dummyobj = gr.JSON(visible=False, elem_id="JobManagerDummyObject_preCall")
|
||||
pre_call_dummyobj.change(
|
||||
partial(self._pre_call_func, func_key, call_dummyobj, *job_ui_params),
|
||||
[self._session_key],
|
||||
[call_dummyobj] + job_ui_outputs
|
||||
[call_dummyobj] + job_ui_outputs,
|
||||
queue=False
|
||||
)
|
||||
|
||||
# Now replace the original function with one that creates a JobInfo and triggers the dummy obj
|
||||
# Add any components that we want the runtime values for
|
||||
added_inputs = [self._session_key, job_ui._rec_steps_checkbox, job_ui._save_rec_steps_to_gallery_chkbx,
|
||||
job_ui._save_rec_steps_to_file_chkbx, job_ui._rec_steps_intrvl_sldr]
|
||||
|
||||
def wrapped_func(*inputs):
|
||||
session_key = inputs[-1]
|
||||
inputs = inputs[:-1]
|
||||
# Now replace the original function with one that creates a JobInfo and triggers the dummy obj
|
||||
def wrapped_func(*wrapped_inputs):
|
||||
# Remove the added_inputs (pop opposite order of list)
|
||||
|
||||
wrapped_inputs = list(wrapped_inputs)
|
||||
rec_steps_interval: int = wrapped_inputs.pop()
|
||||
save_rec_steps_file: bool = wrapped_inputs.pop()
|
||||
save_rec_steps_grid: bool = wrapped_inputs.pop()
|
||||
record_steps_enabled: bool = wrapped_inputs.pop()
|
||||
session_key: str = wrapped_inputs.pop()
|
||||
job_inputs = tuple(wrapped_inputs)
|
||||
|
||||
# Get or create a session for this key
|
||||
session_info = self._sessions.setdefault(session_key, SessionInfo())
|
||||
|
||||
# Is this session already running this job?
|
||||
if func_key in session_info.jobs:
|
||||
return {status_text: "This session is already running that function!"}
|
||||
job_info = session_info.jobs[func_key]
|
||||
# If the job seems stuck in 'starting' then go ahead and toss it
|
||||
if not job_info.started and time.time() > job_info.timestamp + JobManager.JOB_MAX_START_TIME:
|
||||
job_info.should_stop.set()
|
||||
job_info.stop_cur_iter.set()
|
||||
session_info.jobs.pop(func_key)
|
||||
return {job_ui._status_text: "Canceled possibly hung job. Try again"}
|
||||
return {job_ui._status_text: "This session is already running that function!"}
|
||||
|
||||
# Is this a new run of a previously finished job? Clear old info
|
||||
if func_key in session_info.finished_jobs:
|
||||
session_info.finished_jobs.pop(func_key)
|
||||
|
||||
job_token = self._get_job_token(block=False)
|
||||
job = JobInfo(inputs=inputs, func=func, removed_output_idxs=removed_idxs, session_key=session_key,
|
||||
job_token=job_token)
|
||||
job = JobInfo(
|
||||
inputs=job_inputs, func=func, removed_output_idxs=removed_idxs, session_key=session_key,
|
||||
job_token=job_token, rec_steps_enabled=record_steps_enabled, rec_steps_intrvl=rec_steps_interval,
|
||||
rec_steps_to_gallery=save_rec_steps_grid, rec_steps_to_file=save_rec_steps_file, timestamp=time.time())
|
||||
session_info.jobs[func_key] = job
|
||||
|
||||
ret = {pre_call_dummyobj: triggerChangeEvent()}
|
||||
if job_token is None:
|
||||
ret[status_text] = "Job is queued"
|
||||
ret[job_ui._status_text] = "Job is queued"
|
||||
return ret
|
||||
|
||||
return wrapped_func, inputs, [pre_call_dummyobj, status_text]
|
||||
return wrapped_func, inputs + added_inputs, [pre_call_dummyobj, job_ui._status_text]
|
||||
|
@ -9,10 +9,10 @@ import re
|
||||
def change_image_editor_mode(choice, cropped_image, masked_image, resize_mode, width, height):
|
||||
if choice == "Mask":
|
||||
update_image_result = update_image_mask(cropped_image, resize_mode, width, height)
|
||||
return [gr.update(visible=False), update_image_result, gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), gr.update(visible=True), gr.update(visible=True)]
|
||||
return [gr.update(visible=False), update_image_result, gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), gr.update(visible=True), gr.update(visible=True), gr.update(visible=True)]
|
||||
|
||||
update_image_result = update_image_mask(masked_image["image"] if masked_image is not None else None, resize_mode, width, height)
|
||||
return [update_image_result, gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), gr.update(visible=False)]
|
||||
return [update_image_result, gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)]
|
||||
|
||||
def update_image_mask(cropped_image, resize_mode, width, height):
|
||||
resized_cropped_image = resize_image(resize_mode, cropped_image, width, height) if cropped_image else None
|
||||
|
BIN
images/nsfw.jpeg
Normal file
BIN
images/nsfw.jpeg
Normal file
Binary file not shown.
After Width: | Height: | Size: 25 KiB |
@ -7,6 +7,8 @@ from einops import rearrange, repeat
|
||||
|
||||
from ldm.modules.diffusionmodules.util import checkpoint
|
||||
|
||||
import psutil
|
||||
|
||||
|
||||
def exists(val):
|
||||
return val is not None
|
||||
@ -167,30 +169,98 @@ class CrossAttention(nn.Module):
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
|
||||
if torch.cuda.is_available():
|
||||
self.einsum_op = self.einsum_op_cuda
|
||||
else:
|
||||
self.mem_total = psutil.virtual_memory().total / (1024**3)
|
||||
self.einsum_op = self.einsum_op_mps_v1 if self.mem_total >= 32 else self.einsum_op_mps_v2
|
||||
|
||||
def einsum_op_compvis(self, q, k, v, r1):
|
||||
s1 = einsum('b i d, b j d -> b i j', q, k) * self.scale # faster
|
||||
s2 = s1.softmax(dim=-1, dtype=q.dtype)
|
||||
del s1
|
||||
r1 = einsum('b i j, b j d -> b i d', s2, v)
|
||||
del s2
|
||||
return r1
|
||||
|
||||
def einsum_op_mps_v1(self, q, k, v, r1):
|
||||
if q.shape[1] <= 4096: # (512x512) max q.shape[1]: 4096
|
||||
r1 = self.einsum_op_compvis(q, k, v, r1)
|
||||
else:
|
||||
slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1]))
|
||||
for i in range(0, q.shape[1], slice_size):
|
||||
end = i + slice_size
|
||||
s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * self.scale
|
||||
s2 = s1.softmax(dim=-1, dtype=r1.dtype)
|
||||
del s1
|
||||
r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
|
||||
del s2
|
||||
return r1
|
||||
|
||||
def einsum_op_mps_v2(self, q, k, v, r1):
|
||||
if self.mem_total >= 8 and q.shape[1] <= 4096:
|
||||
r1 = self.einsum_op_compvis(q, k, v, r1)
|
||||
else:
|
||||
slice_size = 1
|
||||
for i in range(0, q.shape[0], slice_size):
|
||||
end = min(q.shape[0], i + slice_size)
|
||||
s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end])
|
||||
s1 *= self.scale
|
||||
s2 = s1.softmax(dim=-1, dtype=r1.dtype)
|
||||
del s1
|
||||
r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end])
|
||||
del s2
|
||||
return r1
|
||||
|
||||
def einsum_op_cuda(self, q, k, v, r1):
|
||||
stats = torch.cuda.memory_stats(q.device)
|
||||
mem_active = stats['active_bytes.all.current']
|
||||
mem_reserved = stats['reserved_bytes.all.current']
|
||||
mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
|
||||
mem_free_torch = mem_reserved - mem_active
|
||||
mem_free_total = mem_free_cuda + mem_free_torch
|
||||
|
||||
gb = 1024 ** 3
|
||||
tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * 4
|
||||
mem_required = tensor_size * 2.5
|
||||
steps = 1
|
||||
|
||||
if mem_required > mem_free_total:
|
||||
steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))
|
||||
|
||||
if steps > 64:
|
||||
max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
|
||||
raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
|
||||
f'Need: {mem_required/64/gb:0.1f}GB free, Have:{mem_free_total/gb:0.1f}GB free')
|
||||
|
||||
slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
|
||||
for i in range(0, q.shape[1], slice_size):
|
||||
end = min(q.shape[1], i + slice_size)
|
||||
s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * self.scale
|
||||
s2 = s1.softmax(dim=-1, dtype=r1.dtype)
|
||||
del s1
|
||||
r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
|
||||
del s2
|
||||
return r1
|
||||
|
||||
def forward(self, x, context=None, mask=None):
|
||||
h = self.heads
|
||||
|
||||
q = self.to_q(x)
|
||||
context = default(context, x)
|
||||
del x
|
||||
k = self.to_k(context)
|
||||
v = self.to_v(context)
|
||||
del context
|
||||
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
|
||||
|
||||
sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
|
||||
|
||||
if exists(mask):
|
||||
mask = rearrange(mask, 'b ... -> b (...)')
|
||||
max_neg_value = -torch.finfo(sim.dtype).max
|
||||
mask = repeat(mask, 'b j -> (b h) () j', h=h)
|
||||
sim.masked_fill_(~mask, max_neg_value)
|
||||
|
||||
# attention, what we cannot get enough of
|
||||
attn = sim.softmax(dim=-1)
|
||||
|
||||
out = einsum('b i j, b j d -> b i d', attn, v)
|
||||
out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
|
||||
return self.to_out(out)
|
||||
r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
|
||||
r1 = self.einsum_op(q, k, v, r1)
|
||||
del q, k, v
|
||||
r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
|
||||
del r1
|
||||
return self.to_out(r2)
|
||||
|
||||
|
||||
class BasicTransformerBlock(nn.Module):
|
||||
@ -209,9 +279,10 @@ class BasicTransformerBlock(nn.Module):
|
||||
return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)
|
||||
|
||||
def _forward(self, x, context=None):
|
||||
x = self.attn1(self.norm1(x)) + x
|
||||
x = self.attn2(self.norm2(x), context=context) + x
|
||||
x = self.ff(self.norm3(x)) + x
|
||||
x = x.contiguous() if x.device.type == 'mps' else x
|
||||
x += self.attn1(self.norm1(x))
|
||||
x += self.attn2(self.norm2(x), context=context)
|
||||
x += self.ff(self.norm3(x))
|
||||
return x
|
||||
|
||||
|
||||
|
@ -1,4 +1,5 @@
|
||||
# pytorch_diffusion + derived encoder decoder
|
||||
import gc
|
||||
import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@ -119,18 +120,30 @@ class ResnetBlock(nn.Module):
|
||||
padding=0)
|
||||
|
||||
def forward(self, x, temb):
|
||||
h = x
|
||||
h = self.norm1(h)
|
||||
h = nonlinearity(h)
|
||||
h = self.conv1(h)
|
||||
h1 = x
|
||||
h2 = self.norm1(h1)
|
||||
del h1
|
||||
|
||||
h3 = nonlinearity(h2)
|
||||
del h2
|
||||
|
||||
h4 = self.conv1(h3)
|
||||
del h3
|
||||
|
||||
if temb is not None:
|
||||
h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None]
|
||||
h4 = h4 + self.temb_proj(nonlinearity(temb))[:,:,None,None]
|
||||
|
||||
h = self.norm2(h)
|
||||
h = nonlinearity(h)
|
||||
h = self.dropout(h)
|
||||
h = self.conv2(h)
|
||||
h5 = self.norm2(h4)
|
||||
del h4
|
||||
|
||||
h6 = nonlinearity(h5)
|
||||
del h5
|
||||
|
||||
h7 = self.dropout(h6)
|
||||
del h6
|
||||
|
||||
h8 = self.conv2(h7)
|
||||
del h7
|
||||
|
||||
if self.in_channels != self.out_channels:
|
||||
if self.use_conv_shortcut:
|
||||
@ -138,7 +151,7 @@ class ResnetBlock(nn.Module):
|
||||
else:
|
||||
x = self.nin_shortcut(x)
|
||||
|
||||
return x+h
|
||||
return x + h8
|
||||
|
||||
|
||||
class LinAttnBlock(LinearAttention):
|
||||
@ -178,28 +191,65 @@ class AttnBlock(nn.Module):
|
||||
def forward(self, x):
|
||||
h_ = x
|
||||
h_ = self.norm(h_)
|
||||
q = self.q(h_)
|
||||
k = self.k(h_)
|
||||
q1 = self.q(h_)
|
||||
k1 = self.k(h_)
|
||||
v = self.v(h_)
|
||||
|
||||
# compute attention
|
||||
b,c,h,w = q.shape
|
||||
q = q.reshape(b,c,h*w)
|
||||
q = q.permute(0,2,1) # b,hw,c
|
||||
k = k.reshape(b,c,h*w) # b,c,hw
|
||||
w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
|
||||
w_ = w_ * (int(c)**(-0.5))
|
||||
w_ = torch.nn.functional.softmax(w_, dim=2)
|
||||
b, c, h, w = q1.shape
|
||||
|
||||
# attend to values
|
||||
v = v.reshape(b,c,h*w)
|
||||
w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q)
|
||||
h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
|
||||
h_ = h_.reshape(b,c,h,w)
|
||||
q2 = q1.reshape(b, c, h*w)
|
||||
del q1
|
||||
|
||||
h_ = self.proj_out(h_)
|
||||
q = q2.permute(0, 2, 1) # b,hw,c
|
||||
del q2
|
||||
|
||||
return x+h_
|
||||
k = k1.reshape(b, c, h*w) # b,c,hw
|
||||
del k1
|
||||
|
||||
h_ = torch.zeros_like(k, device=q.device)
|
||||
|
||||
stats = torch.cuda.memory_stats(q.device)
|
||||
mem_active = stats['active_bytes.all.current']
|
||||
mem_reserved = stats['reserved_bytes.all.current']
|
||||
mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
|
||||
mem_free_torch = mem_reserved - mem_active
|
||||
mem_free_total = mem_free_cuda + mem_free_torch
|
||||
|
||||
tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * 4
|
||||
mem_required = tensor_size * 2.5
|
||||
steps = 1
|
||||
|
||||
if mem_required > mem_free_total:
|
||||
steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))
|
||||
|
||||
slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
|
||||
for i in range(0, q.shape[1], slice_size):
|
||||
end = i + slice_size
|
||||
|
||||
w1 = torch.bmm(q[:, i:end], k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
|
||||
w2 = w1 * (int(c)**(-0.5))
|
||||
del w1
|
||||
w3 = torch.nn.functional.softmax(w2, dim=2)
|
||||
del w2
|
||||
|
||||
# attend to values
|
||||
v1 = v.reshape(b, c, h*w)
|
||||
w4 = w3.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
|
||||
del w3
|
||||
|
||||
h_[:, :, i:end] = torch.bmm(v1, w4) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
|
||||
del v1, w4
|
||||
|
||||
h2 = h_.reshape(b, c, h, w)
|
||||
del h_
|
||||
|
||||
h3 = self.proj_out(h2)
|
||||
del h2
|
||||
|
||||
h3 += x
|
||||
|
||||
return h3
|
||||
|
||||
|
||||
def make_attn(in_channels, attn_type="vanilla"):
|
||||
@ -540,31 +590,54 @@ class Decoder(nn.Module):
|
||||
temb = None
|
||||
|
||||
# z to block_in
|
||||
h = self.conv_in(z)
|
||||
h1 = self.conv_in(z)
|
||||
|
||||
# middle
|
||||
h = self.mid.block_1(h, temb)
|
||||
h = self.mid.attn_1(h)
|
||||
h = self.mid.block_2(h, temb)
|
||||
h2 = self.mid.block_1(h1, temb)
|
||||
del h1
|
||||
|
||||
h3 = self.mid.attn_1(h2)
|
||||
del h2
|
||||
|
||||
h = self.mid.block_2(h3, temb)
|
||||
del h3
|
||||
|
||||
# prepare for up sampling
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# upsampling
|
||||
for i_level in reversed(range(self.num_resolutions)):
|
||||
for i_block in range(self.num_res_blocks+1):
|
||||
h = self.up[i_level].block[i_block](h, temb)
|
||||
if len(self.up[i_level].attn) > 0:
|
||||
h = self.up[i_level].attn[i_block](h)
|
||||
t = h
|
||||
h = self.up[i_level].attn[i_block](t)
|
||||
del t
|
||||
|
||||
if i_level != 0:
|
||||
h = self.up[i_level].upsample(h)
|
||||
t = h
|
||||
h = self.up[i_level].upsample(t)
|
||||
del t
|
||||
|
||||
# end
|
||||
if self.give_pre_end:
|
||||
return h
|
||||
|
||||
h = self.norm_out(h)
|
||||
h = nonlinearity(h)
|
||||
h = self.conv_out(h)
|
||||
h1 = self.norm_out(h)
|
||||
del h
|
||||
|
||||
h2 = nonlinearity(h1)
|
||||
del h1
|
||||
|
||||
h = self.conv_out(h2)
|
||||
del h2
|
||||
|
||||
if self.tanh_out:
|
||||
h = torch.tanh(h)
|
||||
t = h
|
||||
h = torch.tanh(t)
|
||||
del t
|
||||
|
||||
return h
|
||||
|
||||
|
||||
|
@ -54,7 +54,8 @@ def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timestep
|
||||
|
||||
# assert ddim_timesteps.shape[0] == num_ddim_timesteps
|
||||
# add one to get the final alpha values right (the ones from first scale to data during sampling)
|
||||
steps_out = ddim_timesteps + 1
|
||||
# steps_out = ddim_timesteps + 1 # removed due to some issues when reaching 1000
|
||||
steps_out = np.where(ddim_timesteps != 999, ddim_timesteps+1, ddim_timesteps)
|
||||
if verbose:
|
||||
print(f'Selected timesteps for ddim sampler: {steps_out}')
|
||||
return steps_out
|
||||
|
1312
scripts/DeforumStableDiffusion.py
Normal file
1312
scripts/DeforumStableDiffusion.py
Normal file
File diff suppressed because it is too large
Load Diff
46
scripts/ModelManager.py
Normal file
46
scripts/ModelManager.py
Normal file
@ -0,0 +1,46 @@
|
||||
# base webui import and utils.
|
||||
from webui_streamlit import st
|
||||
from sd_utils import *
|
||||
|
||||
# streamlit imports
|
||||
|
||||
|
||||
#other imports
|
||||
import pandas as pd
|
||||
from io import StringIO
|
||||
|
||||
# Temp imports
|
||||
|
||||
|
||||
# end of imports
|
||||
#---------------------------------------------------------------------------------------------------------------
|
||||
|
||||
def layout():
|
||||
#search = st.text_input(label="Search", placeholder="Type the name of the model you want to search for.", help="")
|
||||
|
||||
csvString = f"""
|
||||
,Stable Diffusion v1.4 , ./models/ldm/stable-diffusion-v1 , https://www.googleapis.com/storage/v1/b/aai-blog-files/o/sd-v1-4.ckpt?alt=media
|
||||
,GFPGAN v1.3 , ./src/gfpgan/experiments/pretrained_models , https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth
|
||||
,RealESRGAN_x4plus , ./src/realesrgan/experiments/pretrained_models , https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth
|
||||
,RealESRGAN_x4plus_anime_6B , ./src/realesrgan/experiments/pretrained_models , https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth
|
||||
,Waifu Diffusion v1.2 , ./models/custom , http://wd.links.sd:8880/wd-v1-2-full-ema.ckpt
|
||||
,TrinArt Stable Diffusion v2 , ./models/custom , https://huggingface.co/naclbit/trinart_stable_diffusion_v2/resolve/main/trinart2_step115000.ckpt
|
||||
,Stable Diffusion Concept Library , ./models/customsd-concepts-library , https://github.com/sd-webui/sd-concepts-library
|
||||
"""
|
||||
colms = st.columns((1, 3, 5, 5))
|
||||
columns = ["№",'Model Name','Save Location','Download Link']
|
||||
|
||||
# Convert String into StringIO
|
||||
csvStringIO = StringIO(csvString)
|
||||
df = pd.read_csv(csvStringIO, sep=",", header=None, names=columns)
|
||||
|
||||
for col, field_name in zip(colms, columns):
|
||||
# table header
|
||||
col.write(field_name)
|
||||
|
||||
for x, model_name in enumerate(df["Model Name"]):
|
||||
col1, col2, col3, col4 = st.columns((1, 3, 4, 6))
|
||||
col1.write(x) # index
|
||||
col2.write(df['Model Name'][x])
|
||||
col3.write(df['Save Location'][x])
|
||||
col4.write(df['Download Link'][x])
|
5
scripts/Settings.py
Normal file
5
scripts/Settings.py
Normal file
@ -0,0 +1,5 @@
|
||||
from webui_streamlit import st
|
||||
|
||||
# The global settings section will be moved to the Settings page.
|
||||
#with st.expander("Global Settings:"):
|
||||
st.write("Global Settings:")
|
216
scripts/home.py
Normal file
216
scripts/home.py
Normal file
@ -0,0 +1,216 @@
|
||||
# base webui import and utils.
|
||||
from webui_streamlit import st
|
||||
from sd_utils import *
|
||||
|
||||
# streamlit imports
|
||||
|
||||
|
||||
#other imports
|
||||
|
||||
# Temp imports
|
||||
|
||||
|
||||
# end of imports
|
||||
#---------------------------------------------------------------------------------------------------------------
|
||||
|
||||
import os
|
||||
from PIL import Image
|
||||
|
||||
try:
|
||||
# this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start.
|
||||
from transformers import logging
|
||||
|
||||
logging.set_verbosity_error()
|
||||
except:
|
||||
pass
|
||||
|
||||
class plugin_info():
|
||||
plugname = "home"
|
||||
description = "Home"
|
||||
isTab = True
|
||||
displayPriority = 0
|
||||
|
||||
def getLatestGeneratedImagesFromPath():
|
||||
#get the latest images from the generated images folder
|
||||
#get the path to the generated images folder
|
||||
generatedImagesPath = os.path.join(os.getcwd(),'outputs')
|
||||
#get all the files from the folders and subfolders
|
||||
files = []
|
||||
#get the latest 10 images from the output folder without walking the subfolders
|
||||
for r, d, f in os.walk(generatedImagesPath):
|
||||
for file in f:
|
||||
if '.png' in file:
|
||||
files.append(os.path.join(r, file))
|
||||
#sort the files by date
|
||||
files.sort(reverse=True, key=os.path.getmtime)
|
||||
latest = files[:90]
|
||||
latest.reverse()
|
||||
|
||||
# reverse the list so the latest images are first and truncate to
|
||||
# a reasonable number of images, 10 pages worth
|
||||
return [Image.open(f) for f in latest]
|
||||
|
||||
def get_images_from_lexica():
|
||||
#scrape images from lexica.art
|
||||
#get the html from the page
|
||||
#get the html with cookies and javascript
|
||||
apiEndpoint = r'https://lexica.art/api/trpc/prompts.infinitePrompts?batch=1&input=%7B%220%22%3A%7B%22json%22%3A%7B%22limit%22%3A10%2C%22text%22%3A%22%22%2C%22cursor%22%3A10%7D%7D%7D'
|
||||
#REST API call
|
||||
#
|
||||
from requests_html import HTMLSession
|
||||
session = HTMLSession()
|
||||
|
||||
response = session.get(apiEndpoint)
|
||||
#req = requests.Session()
|
||||
#req.headers['user-agent'] = 'Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/96.0.4664.45 Safari/537.36'
|
||||
#response = req.get(apiEndpoint)
|
||||
print(response.status_code)
|
||||
print(response.text)
|
||||
#get the json from the response
|
||||
#json = response.json()
|
||||
#get the prompts from the json
|
||||
print(response)
|
||||
#session = requests.Session()
|
||||
#parseEndpointJson = session.get(apiEndpoint,headers=headers,verify=False)
|
||||
#print(parseEndpointJson)
|
||||
#print('test2')
|
||||
#page = requests.get("https://lexica.art/", headers={'User-Agent': 'Mozilla/5.0'})
|
||||
#parse the html
|
||||
#soup = BeautifulSoup(page.content, 'html.parser')
|
||||
#find all the images
|
||||
#print(soup)
|
||||
#images = soup.find_all('alt-image')
|
||||
#create a list to store the image urls
|
||||
image_urls = []
|
||||
#loop through the images
|
||||
for image in images:
|
||||
#get the url
|
||||
image_url = image['src']
|
||||
#add it to the list
|
||||
image_urls.append('http://www.lexica.art/'+image_url)
|
||||
#return the list
|
||||
print(image_urls)
|
||||
return image_urls
|
||||
|
||||
def layout():
|
||||
#streamlit home page layout
|
||||
#center the title
|
||||
st.markdown("<h1 style='text-align: center; color: white;'>Welcome, let's make some 🎨</h1>", unsafe_allow_html=True)
|
||||
#make a gallery of images
|
||||
#st.markdown("<h2 style='text-align: center; color: white;'>Gallery</h2>", unsafe_allow_html=True)
|
||||
#create a gallery of images using columns
|
||||
#col1, col2, col3 = st.columns(3)
|
||||
#load the images
|
||||
#create 3 columns
|
||||
# create a tab for the gallery
|
||||
#st.markdown("<h2 style='text-align: center; color: white;'>Gallery</h2>", unsafe_allow_html=True)
|
||||
#st.markdown("<h2 style='text-align: center; color: white;'>Gallery</h2>", unsafe_allow_html=True)
|
||||
|
||||
history_tab, discover_tabs = st.tabs(["History","Discover"])
|
||||
|
||||
latestImages = getLatestGeneratedImagesFromPath()
|
||||
st.session_state['latestImages'] = latestImages
|
||||
|
||||
with history_tab:
|
||||
##---------------------------------------------------------
|
||||
## image slideshow test
|
||||
## Number of entries per screen
|
||||
#slideshow_N = 9
|
||||
#slideshow_page_number = 0
|
||||
#slideshow_last_page = len(latestImages) // slideshow_N
|
||||
|
||||
## Add a next button and a previous button
|
||||
|
||||
#slideshow_prev, slideshow_image_col , slideshow_next = st.columns([1, 10, 1])
|
||||
|
||||
#with slideshow_image_col:
|
||||
#slideshow_image = st.empty()
|
||||
|
||||
#slideshow_image.image(st.session_state['latestImages'][0])
|
||||
|
||||
#current_image = 0
|
||||
|
||||
#if slideshow_next.button("Next", key=1):
|
||||
##print (current_image+1)
|
||||
#current_image = current_image+1
|
||||
#slideshow_image.image(st.session_state['latestImages'][current_image+1])
|
||||
#if slideshow_prev.button("Previous", key=0):
|
||||
##print ([current_image-1])
|
||||
#current_image = current_image-1
|
||||
#slideshow_image.image(st.session_state['latestImages'][current_image - 1])
|
||||
|
||||
|
||||
#---------------------------------------------------------
|
||||
|
||||
# image gallery
|
||||
# Number of entries per screen
|
||||
gallery_N = 9
|
||||
if not "galleryPage" in st.session_state:
|
||||
st.session_state["galleryPage"] = 0
|
||||
gallery_last_page = len(latestImages) // gallery_N
|
||||
|
||||
# Add a next button and a previous button
|
||||
|
||||
gallery_prev, gallery_refresh, gallery_pagination , gallery_next = st.columns([2, 2, 8, 1])
|
||||
|
||||
# the pagination doesnt work for now so its better to disable the buttons.
|
||||
|
||||
if gallery_refresh.button("Refresh", key=4):
|
||||
st.session_state["galleryPage"] = 0
|
||||
|
||||
if gallery_next.button("Next", key=3):
|
||||
|
||||
if st.session_state["galleryPage"] + 1 > gallery_last_page:
|
||||
st.session_state["galleryPage"] = 0
|
||||
else:
|
||||
st.session_state["galleryPage"] += 1
|
||||
|
||||
if gallery_prev.button("Previous", key=2):
|
||||
|
||||
if st.session_state["galleryPage"] - 1 < 0:
|
||||
st.session_state["galleryPage"] = gallery_last_page
|
||||
else:
|
||||
st.session_state["galleryPage"] -= 1
|
||||
|
||||
print(st.session_state["galleryPage"])
|
||||
# Get start and end indices of the next page of the dataframe
|
||||
gallery_start_idx = st.session_state["galleryPage"] * gallery_N
|
||||
gallery_end_idx = (1 + st.session_state["galleryPage"]) * gallery_N
|
||||
|
||||
#---------------------------------------------------------
|
||||
|
||||
placeholder = st.empty()
|
||||
|
||||
#populate the 3 images per column
|
||||
with placeholder.container():
|
||||
col1, col2, col3 = st.columns(3)
|
||||
col1_cont = st.container()
|
||||
col2_cont = st.container()
|
||||
col3_cont = st.container()
|
||||
|
||||
print (len(st.session_state['latestImages']))
|
||||
images = list(reversed(st.session_state['latestImages']))[gallery_start_idx:(gallery_start_idx+gallery_N)]
|
||||
|
||||
with col1_cont:
|
||||
with col1:
|
||||
[st.image(images[index]) for index in [0, 3, 6] if index < len(images)]
|
||||
with col2_cont:
|
||||
with col2:
|
||||
[st.image(images[index]) for index in [1, 4, 7] if index < len(images)]
|
||||
with col3_cont:
|
||||
with col3:
|
||||
[st.image(images[index]) for index in [2, 5, 8] if index < len(images)]
|
||||
|
||||
|
||||
st.session_state['historyTab'] = [history_tab,col1,col2,col3,placeholder,col1_cont,col2_cont,col3_cont]
|
||||
|
||||
with discover_tabs:
|
||||
st.markdown("<h1 style='text-align: center; color: white;'>Soon :)</h1>", unsafe_allow_html=True)
|
||||
|
||||
#display the images
|
||||
#add a button to the gallery
|
||||
#st.markdown("<h2 style='text-align: center; color: white;'>Try it out</h2>", unsafe_allow_html=True)
|
||||
#create a button to the gallery
|
||||
#if st.button("Try it out"):
|
||||
#if the button is clicked, go to the gallery
|
||||
#st.experimental_rerun()
|
592
scripts/img2img.py
Normal file
592
scripts/img2img.py
Normal file
@ -0,0 +1,592 @@
|
||||
# base webui import and utils.
|
||||
from webui_streamlit import st
|
||||
from sd_utils import *
|
||||
|
||||
# streamlit imports
|
||||
from streamlit import StopException
|
||||
|
||||
#other imports
|
||||
import cv2
|
||||
from PIL import Image, ImageOps
|
||||
import torch
|
||||
import k_diffusion as K
|
||||
import numpy as np
|
||||
import time
|
||||
import torch
|
||||
import skimage
|
||||
from ldm.models.diffusion.ddim import DDIMSampler
|
||||
from ldm.models.diffusion.plms import PLMSSampler
|
||||
# Temp imports
|
||||
|
||||
|
||||
# end of imports
|
||||
#---------------------------------------------------------------------------------------------------------------
|
||||
|
||||
|
||||
try:
|
||||
# this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start.
|
||||
from transformers import logging
|
||||
|
||||
logging.set_verbosity_error()
|
||||
except:
|
||||
pass
|
||||
|
||||
def img2img(prompt: str = '', init_info: any = None, init_info_mask: any = None, mask_mode: int = 0, mask_blur_strength: int = 3,
|
||||
mask_restore: bool = False, ddim_steps: int = 50, sampler_name: str = 'DDIM',
|
||||
n_iter: int = 1, cfg_scale: float = 7.5, denoising_strength: float = 0.8,
|
||||
seed: int = -1, noise_mode: int = 0, find_noise_steps: str = "", height: int = 512, width: int = 512, resize_mode: int = 0, fp = None,
|
||||
variant_amount: float = None, variant_seed: int = None, ddim_eta:float = 0.0,
|
||||
write_info_files:bool = True, RealESRGAN_model: str = "RealESRGAN_x4plus_anime_6B",
|
||||
separate_prompts:bool = False, normalize_prompt_weights:bool = True,
|
||||
save_individual_images: bool = True, save_grid: bool = True, group_by_prompt: bool = True,
|
||||
save_as_jpg: bool = True, use_GFPGAN: bool = True, use_RealESRGAN: bool = True, loopback: bool = False,
|
||||
random_seed_loopback: bool = False
|
||||
):
|
||||
|
||||
outpath = st.session_state['defaults'].general.outdir_img2img or st.session_state['defaults'].general.outdir or "outputs/img2img-samples"
|
||||
#err = False
|
||||
#loopback = False
|
||||
#skip_save = False
|
||||
seed = seed_to_int(seed)
|
||||
|
||||
batch_size = 1
|
||||
|
||||
#prompt_matrix = 0
|
||||
#normalize_prompt_weights = 1 in toggles
|
||||
#loopback = 2 in toggles
|
||||
#random_seed_loopback = 3 in toggles
|
||||
#skip_save = 4 not in toggles
|
||||
#save_grid = 5 in toggles
|
||||
#sort_samples = 6 in toggles
|
||||
#write_info_files = 7 in toggles
|
||||
#write_sample_info_to_log_file = 8 in toggles
|
||||
#jpg_sample = 9 in toggles
|
||||
#use_GFPGAN = 10 in toggles
|
||||
#use_RealESRGAN = 11 in toggles
|
||||
|
||||
if sampler_name == 'PLMS':
|
||||
sampler = PLMSSampler(st.session_state["model"])
|
||||
elif sampler_name == 'DDIM':
|
||||
sampler = DDIMSampler(st.session_state["model"])
|
||||
elif sampler_name == 'k_dpm_2_a':
|
||||
sampler = KDiffusionSampler(st.session_state["model"],'dpm_2_ancestral')
|
||||
elif sampler_name == 'k_dpm_2':
|
||||
sampler = KDiffusionSampler(st.session_state["model"],'dpm_2')
|
||||
elif sampler_name == 'k_euler_a':
|
||||
sampler = KDiffusionSampler(st.session_state["model"],'euler_ancestral')
|
||||
elif sampler_name == 'k_euler':
|
||||
sampler = KDiffusionSampler(st.session_state["model"],'euler')
|
||||
elif sampler_name == 'k_heun':
|
||||
sampler = KDiffusionSampler(st.session_state["model"],'heun')
|
||||
elif sampler_name == 'k_lms':
|
||||
sampler = KDiffusionSampler(st.session_state["model"],'lms')
|
||||
else:
|
||||
raise Exception("Unknown sampler: " + sampler_name)
|
||||
|
||||
def process_init_mask(init_mask: Image):
|
||||
if init_mask.mode == "RGBA":
|
||||
init_mask = init_mask.convert('RGBA')
|
||||
background = Image.new('RGBA', init_mask.size, (0, 0, 0))
|
||||
init_mask = Image.alpha_composite(background, init_mask)
|
||||
init_mask = init_mask.convert('RGB')
|
||||
return init_mask
|
||||
|
||||
init_img = init_info
|
||||
init_mask = None
|
||||
if mask_mode == 0:
|
||||
if init_info_mask:
|
||||
init_mask = process_init_mask(init_info_mask)
|
||||
elif mask_mode == 1:
|
||||
if init_info_mask:
|
||||
init_mask = process_init_mask(init_info_mask)
|
||||
init_mask = ImageOps.invert(init_mask)
|
||||
elif mask_mode == 2:
|
||||
init_img_transparency = init_img.split()[-1].convert('L')#.point(lambda x: 255 if x > 0 else 0, mode='1')
|
||||
init_mask = init_img_transparency
|
||||
init_mask = init_mask.convert("RGB")
|
||||
init_mask = resize_image(resize_mode, init_mask, width, height)
|
||||
init_mask = init_mask.convert("RGB")
|
||||
|
||||
assert 0. <= denoising_strength <= 1., 'can only work with strength in [0.0, 1.0]'
|
||||
t_enc = int(denoising_strength * ddim_steps)
|
||||
|
||||
if init_mask is not None and (noise_mode == 2 or noise_mode == 3) and init_img is not None:
|
||||
noise_q = 0.99
|
||||
color_variation = 0.0
|
||||
mask_blend_factor = 1.0
|
||||
|
||||
np_init = (np.asarray(init_img.convert("RGB"))/255.0).astype(np.float64) # annoyingly complex mask fixing
|
||||
np_mask_rgb = 1. - (np.asarray(ImageOps.invert(init_mask).convert("RGB"))/255.0).astype(np.float64)
|
||||
np_mask_rgb -= np.min(np_mask_rgb)
|
||||
np_mask_rgb /= np.max(np_mask_rgb)
|
||||
np_mask_rgb = 1. - np_mask_rgb
|
||||
np_mask_rgb_hardened = 1. - (np_mask_rgb < 0.99).astype(np.float64)
|
||||
blurred = skimage.filters.gaussian(np_mask_rgb_hardened[:], sigma=16., channel_axis=2, truncate=32.)
|
||||
blurred2 = skimage.filters.gaussian(np_mask_rgb_hardened[:], sigma=16., channel_axis=2, truncate=32.)
|
||||
#np_mask_rgb_dilated = np_mask_rgb + blurred # fixup mask todo: derive magic constants
|
||||
#np_mask_rgb = np_mask_rgb + blurred
|
||||
np_mask_rgb_dilated = np.clip((np_mask_rgb + blurred2) * 0.7071, 0., 1.)
|
||||
np_mask_rgb = np.clip((np_mask_rgb + blurred) * 0.7071, 0., 1.)
|
||||
|
||||
noise_rgb = get_matched_noise(np_init, np_mask_rgb, noise_q, color_variation)
|
||||
blend_mask_rgb = np.clip(np_mask_rgb_dilated,0.,1.) ** (mask_blend_factor)
|
||||
noised = noise_rgb[:]
|
||||
blend_mask_rgb **= (2.)
|
||||
noised = np_init[:] * (1. - blend_mask_rgb) + noised * blend_mask_rgb
|
||||
|
||||
np_mask_grey = np.sum(np_mask_rgb, axis=2)/3.
|
||||
ref_mask = np_mask_grey < 1e-3
|
||||
|
||||
all_mask = np.ones((height, width), dtype=bool)
|
||||
noised[all_mask,:] = skimage.exposure.match_histograms(noised[all_mask,:]**1., noised[ref_mask,:], channel_axis=1)
|
||||
|
||||
init_img = Image.fromarray(np.clip(noised * 255., 0., 255.).astype(np.uint8), mode="RGB")
|
||||
st.session_state["editor_image"].image(init_img) # debug
|
||||
|
||||
def init():
|
||||
image = init_img.convert('RGB')
|
||||
image = np.array(image).astype(np.float32) / 255.0
|
||||
image = image[None].transpose(0, 3, 1, 2)
|
||||
image = torch.from_numpy(image)
|
||||
|
||||
mask_channel = None
|
||||
if init_mask:
|
||||
alpha = resize_image(resize_mode, init_mask, width // 8, height // 8)
|
||||
mask_channel = alpha.split()[-1]
|
||||
|
||||
mask = None
|
||||
if mask_channel is not None:
|
||||
mask = np.array(mask_channel).astype(np.float32) / 255.0
|
||||
mask = (1 - mask)
|
||||
mask = np.tile(mask, (4, 1, 1))
|
||||
mask = mask[None].transpose(0, 1, 2, 3)
|
||||
mask = torch.from_numpy(mask).to(st.session_state["device"])
|
||||
|
||||
if st.session_state['defaults'].general.optimized:
|
||||
st.session_state.modelFS.to(st.session_state["device"] )
|
||||
|
||||
init_image = 2. * image - 1.
|
||||
init_image = init_image.to(st.session_state["device"])
|
||||
init_latent = (st.session_state["model"] if not st.session_state['defaults'].general.optimized else st.session_state.modelFS).get_first_stage_encoding((st.session_state["model"] if not st.session_state['defaults'].general.optimized else modelFS).encode_first_stage(init_image)) # move to latent space
|
||||
|
||||
if st.session_state['defaults'].general.optimized:
|
||||
mem = torch.cuda.memory_allocated()/1e6
|
||||
st.session_state.modelFS.to("cpu")
|
||||
while(torch.cuda.memory_allocated()/1e6 >= mem):
|
||||
time.sleep(1)
|
||||
|
||||
return init_latent, mask,
|
||||
|
||||
def sample(init_data, x, conditioning, unconditional_conditioning, sampler_name):
|
||||
t_enc_steps = t_enc
|
||||
obliterate = False
|
||||
if ddim_steps == t_enc_steps:
|
||||
t_enc_steps = t_enc_steps - 1
|
||||
obliterate = True
|
||||
|
||||
if sampler_name != 'DDIM':
|
||||
x0, z_mask = init_data
|
||||
|
||||
sigmas = sampler.model_wrap.get_sigmas(ddim_steps)
|
||||
noise = x * sigmas[ddim_steps - t_enc_steps - 1]
|
||||
|
||||
xi = x0 + noise
|
||||
|
||||
# Obliterate masked image
|
||||
if z_mask is not None and obliterate:
|
||||
random = torch.randn(z_mask.shape, device=xi.device)
|
||||
xi = (z_mask * noise) + ((1-z_mask) * xi)
|
||||
|
||||
sigma_sched = sigmas[ddim_steps - t_enc_steps - 1:]
|
||||
model_wrap_cfg = CFGMaskedDenoiser(sampler.model_wrap)
|
||||
samples_ddim = K.sampling.__dict__[f'sample_{sampler.get_sampler_name()}'](model_wrap_cfg, xi, sigma_sched,
|
||||
extra_args={'cond': conditioning, 'uncond': unconditional_conditioning,
|
||||
'cond_scale': cfg_scale, 'mask': z_mask, 'x0': x0, 'xi': xi}, disable=False,
|
||||
callback=generation_callback)
|
||||
else:
|
||||
|
||||
x0, z_mask = init_data
|
||||
|
||||
sampler.make_schedule(ddim_num_steps=ddim_steps, ddim_eta=0.0, verbose=False)
|
||||
z_enc = sampler.stochastic_encode(x0, torch.tensor([t_enc_steps]*batch_size).to(st.session_state["device"] ))
|
||||
|
||||
# Obliterate masked image
|
||||
if z_mask is not None and obliterate:
|
||||
random = torch.randn(z_mask.shape, device=z_enc.device)
|
||||
z_enc = (z_mask * random) + ((1-z_mask) * z_enc)
|
||||
|
||||
# decode it
|
||||
samples_ddim = sampler.decode(z_enc, conditioning, t_enc_steps,
|
||||
unconditional_guidance_scale=cfg_scale,
|
||||
unconditional_conditioning=unconditional_conditioning,
|
||||
z_mask=z_mask, x0=x0)
|
||||
return samples_ddim
|
||||
|
||||
|
||||
|
||||
if loopback:
|
||||
output_images, info = None, None
|
||||
history = []
|
||||
initial_seed = None
|
||||
|
||||
do_color_correction = False
|
||||
try:
|
||||
from skimage import exposure
|
||||
do_color_correction = True
|
||||
except:
|
||||
print("Install scikit-image to perform color correction on loopback")
|
||||
|
||||
for i in range(n_iter):
|
||||
if do_color_correction and i == 0:
|
||||
correction_target = cv2.cvtColor(np.asarray(init_img.copy()), cv2.COLOR_RGB2LAB)
|
||||
|
||||
output_images, seed, info, stats = process_images(
|
||||
outpath=outpath,
|
||||
func_init=init,
|
||||
func_sample=sample,
|
||||
prompt=prompt,
|
||||
seed=seed,
|
||||
sampler_name=sampler_name,
|
||||
save_grid=save_grid,
|
||||
batch_size=1,
|
||||
n_iter=1,
|
||||
steps=ddim_steps,
|
||||
cfg_scale=cfg_scale,
|
||||
width=width,
|
||||
height=height,
|
||||
prompt_matrix=separate_prompts,
|
||||
use_GFPGAN=use_GFPGAN,
|
||||
use_RealESRGAN=use_RealESRGAN, # Forcefully disable upscaling when using loopback
|
||||
realesrgan_model_name=RealESRGAN_model,
|
||||
normalize_prompt_weights=normalize_prompt_weights,
|
||||
save_individual_images=save_individual_images,
|
||||
init_img=init_img,
|
||||
init_mask=init_mask,
|
||||
mask_blur_strength=mask_blur_strength,
|
||||
mask_restore=mask_restore,
|
||||
denoising_strength=denoising_strength,
|
||||
noise_mode=noise_mode,
|
||||
find_noise_steps=find_noise_steps,
|
||||
resize_mode=resize_mode,
|
||||
uses_loopback=loopback,
|
||||
uses_random_seed_loopback=random_seed_loopback,
|
||||
sort_samples=group_by_prompt,
|
||||
write_info_files=write_info_files,
|
||||
jpg_sample=save_as_jpg
|
||||
)
|
||||
|
||||
if initial_seed is None:
|
||||
initial_seed = seed
|
||||
|
||||
input_image = init_img
|
||||
init_img = output_images[0]
|
||||
|
||||
if do_color_correction and correction_target is not None:
|
||||
init_img = Image.fromarray(cv2.cvtColor(exposure.match_histograms(
|
||||
cv2.cvtColor(
|
||||
np.asarray(init_img),
|
||||
cv2.COLOR_RGB2LAB
|
||||
),
|
||||
correction_target,
|
||||
channel_axis=2
|
||||
), cv2.COLOR_LAB2RGB).astype("uint8"))
|
||||
if mask_restore is True and init_mask is not None:
|
||||
color_mask = init_mask.filter(ImageFilter.GaussianBlur(mask_blur_strength))
|
||||
color_mask = color_mask.convert('L')
|
||||
source_image = input_image.convert('RGB')
|
||||
target_image = init_img.convert('RGB')
|
||||
|
||||
init_img = Image.composite(source_image, target_image, color_mask)
|
||||
|
||||
if not random_seed_loopback:
|
||||
seed = seed + 1
|
||||
else:
|
||||
seed = seed_to_int(None)
|
||||
|
||||
denoising_strength = max(denoising_strength * 0.95, 0.1)
|
||||
history.append(init_img)
|
||||
|
||||
output_images = history
|
||||
seed = initial_seed
|
||||
|
||||
else:
|
||||
output_images, seed, info, stats = process_images(
|
||||
outpath=outpath,
|
||||
func_init=init,
|
||||
func_sample=sample,
|
||||
prompt=prompt,
|
||||
seed=seed,
|
||||
sampler_name=sampler_name,
|
||||
save_grid=save_grid,
|
||||
batch_size=batch_size,
|
||||
n_iter=n_iter,
|
||||
steps=ddim_steps,
|
||||
cfg_scale=cfg_scale,
|
||||
width=width,
|
||||
height=height,
|
||||
prompt_matrix=separate_prompts,
|
||||
use_GFPGAN=use_GFPGAN,
|
||||
use_RealESRGAN=use_RealESRGAN,
|
||||
realesrgan_model_name=RealESRGAN_model,
|
||||
normalize_prompt_weights=normalize_prompt_weights,
|
||||
save_individual_images=save_individual_images,
|
||||
init_img=init_img,
|
||||
init_mask=init_mask,
|
||||
mask_blur_strength=mask_blur_strength,
|
||||
denoising_strength=denoising_strength,
|
||||
noise_mode=noise_mode,
|
||||
find_noise_steps=find_noise_steps,
|
||||
mask_restore=mask_restore,
|
||||
resize_mode=resize_mode,
|
||||
uses_loopback=loopback,
|
||||
sort_samples=group_by_prompt,
|
||||
write_info_files=write_info_files,
|
||||
jpg_sample=save_as_jpg
|
||||
)
|
||||
|
||||
del sampler
|
||||
|
||||
return output_images, seed, info, stats
|
||||
|
||||
#
|
||||
|
||||
|
||||
def layout():
|
||||
with st.form("img2img-inputs"):
|
||||
st.session_state["generation_mode"] = "img2img"
|
||||
|
||||
img2img_input_col, img2img_generate_col = st.columns([10,1])
|
||||
with img2img_input_col:
|
||||
#prompt = st.text_area("Input Text","")
|
||||
prompt = st.text_input("Input Text","", placeholder="A corgi wearing a top hat as an oil painting.")
|
||||
|
||||
# Every form must have a submit button, the extra blank spaces is a temp way to align it with the input field. Needs to be done in CSS or some other way.
|
||||
img2img_generate_col.write("")
|
||||
img2img_generate_col.write("")
|
||||
generate_button = img2img_generate_col.form_submit_button("Generate")
|
||||
|
||||
|
||||
# creating the page layout using columns
|
||||
col1_img2img_layout, col2_img2img_layout, col3_img2img_layout = st.columns([1,2,2], gap="small")
|
||||
|
||||
with col1_img2img_layout:
|
||||
# If we have custom models available on the "models/custom"
|
||||
#folder then we show a menu to select which model we want to use, otherwise we use the main model for SD
|
||||
if st.session_state["CustomModel_available"]:
|
||||
st.session_state["custom_model"] = st.selectbox("Custom Model:", st.session_state["custom_models"],
|
||||
index=st.session_state["custom_models"].index(st.session_state['defaults'].general.default_model),
|
||||
help="Select the model you want to use. This option is only available if you have custom models \
|
||||
on your 'models/custom' folder. The model name that will be shown here is the same as the name\
|
||||
the file for the model has on said folder, it is recommended to give the .ckpt file a name that \
|
||||
will make it easier for you to distinguish it from other models. Default: Stable Diffusion v1.4")
|
||||
else:
|
||||
st.session_state["custom_model"] = "Stable Diffusion v1.4"
|
||||
|
||||
|
||||
st.session_state["sampling_steps"] = st.slider("Sampling Steps",
|
||||
value=st.session_state['defaults'].img2img.sampling_steps,
|
||||
min_value=st.session_state['defaults'].img2img.slider_bounds.sampling.lower,
|
||||
max_value=st.session_state['defaults'].img2img.slider_bounds.sampling.upper,
|
||||
step=st.session_state['defaults'].img2img.slider_steps.sampling)
|
||||
|
||||
sampler_name_list = ["k_lms", "k_euler", "k_euler_a", "k_dpm_2", "k_dpm_2_a", "k_heun", "PLMS", "DDIM"]
|
||||
st.session_state["sampler_name"] = st.selectbox("Sampling method",sampler_name_list,
|
||||
index=sampler_name_list.index(st.session_state['defaults'].img2img.sampler_name), help="Sampling method to use.")
|
||||
|
||||
mask_mode_list = ["Mask", "Inverted mask", "Image alpha"]
|
||||
mask_mode = st.selectbox("Mask Mode", mask_mode_list,
|
||||
help="Select how you want your image to be masked.\"Mask\" modifies the image where the mask is white.\n\
|
||||
\"Inverted mask\" modifies the image where the mask is black. \"Image alpha\" modifies the image where the image is transparent."
|
||||
)
|
||||
mask_mode = mask_mode_list.index(mask_mode)
|
||||
|
||||
width = st.slider("Width:", min_value=64, max_value=1024, value=st.session_state['defaults'].img2img.width, step=64)
|
||||
height = st.slider("Height:", min_value=64, max_value=1024, value=st.session_state['defaults'].img2img.height, step=64)
|
||||
seed = st.text_input("Seed:", value=st.session_state['defaults'].img2img.seed, help=" The seed to use, if left blank a random seed will be generated.")
|
||||
noise_mode_list = ["Seed", "Find Noise", "Matched Noise", "Find+Matched Noise"]
|
||||
noise_mode = st.selectbox(
|
||||
"Noise Mode", noise_mode_list,
|
||||
help=""
|
||||
)
|
||||
noise_mode = noise_mode_list.index(noise_mode)
|
||||
find_noise_steps = st.slider("Find Noise Steps", value=100, min_value=1, max_value=500)
|
||||
batch_count = st.slider("Batch count.", min_value=1, max_value=100, value=st.session_state['defaults'].img2img.batch_count, step=1,
|
||||
help="How many iterations or batches of images to generate in total.")
|
||||
|
||||
#
|
||||
with st.expander("Advanced"):
|
||||
separate_prompts = st.checkbox("Create Prompt Matrix.", value=st.session_state['defaults'].img2img.separate_prompts,
|
||||
help="Separate multiple prompts using the `|` character, and get all combinations of them.")
|
||||
normalize_prompt_weights = st.checkbox("Normalize Prompt Weights.", value=st.session_state['defaults'].img2img.normalize_prompt_weights,
|
||||
help="Ensure the sum of all weights add up to 1.0")
|
||||
loopback = st.checkbox("Loopback.", value=st.session_state['defaults'].img2img.loopback, help="Use images from previous batch when creating next batch.")
|
||||
random_seed_loopback = st.checkbox("Random loopback seed.", value=st.session_state['defaults'].img2img.random_seed_loopback, help="Random loopback seed")
|
||||
img2img_mask_restore = st.checkbox("Only modify regenerated parts of image",
|
||||
value=st.session_state['defaults'].img2img.mask_restore,
|
||||
help="Enable to restore the unmasked parts of the image with the input, may not blend as well but preserves detail")
|
||||
save_individual_images = st.checkbox("Save individual images.", value=st.session_state['defaults'].img2img.save_individual_images,
|
||||
help="Save each image generated before any filter or enhancement is applied.")
|
||||
save_grid = st.checkbox("Save grid",value=st.session_state['defaults'].img2img.save_grid, help="Save a grid with all the images generated into a single image.")
|
||||
group_by_prompt = st.checkbox("Group results by prompt", value=st.session_state['defaults'].img2img.group_by_prompt,
|
||||
help="Saves all the images with the same prompt into the same folder. \
|
||||
When using a prompt matrix each prompt combination will have its own folder.")
|
||||
write_info_files = st.checkbox("Write Info file", value=st.session_state['defaults'].img2img.write_info_files,
|
||||
help="Save a file next to the image with informartion about the generation.")
|
||||
save_as_jpg = st.checkbox("Save samples as jpg", value=st.session_state['defaults'].img2img.save_as_jpg, help="Saves the images as jpg instead of png.")
|
||||
|
||||
if st.session_state["GFPGAN_available"]:
|
||||
use_GFPGAN = st.checkbox("Use GFPGAN", value=st.session_state['defaults'].img2img.use_GFPGAN, help="Uses the GFPGAN model to improve faces after the generation.\
|
||||
This greatly improve the quality and consistency of faces but uses extra VRAM. Disable if you need the extra VRAM.")
|
||||
else:
|
||||
use_GFPGAN = False
|
||||
|
||||
if st.session_state["RealESRGAN_available"]:
|
||||
st.session_state["use_RealESRGAN"] = st.checkbox("Use RealESRGAN", value=st.session_state['defaults'].img2img.use_RealESRGAN,
|
||||
help="Uses the RealESRGAN model to upscale the images after the generation.\
|
||||
This greatly improve the quality and lets you have high resolution images but uses extra VRAM. Disable if you need the extra VRAM.")
|
||||
st.session_state["RealESRGAN_model"] = st.selectbox("RealESRGAN model", ["RealESRGAN_x4plus", "RealESRGAN_x4plus_anime_6B"], index=0)
|
||||
else:
|
||||
st.session_state["use_RealESRGAN"] = False
|
||||
st.session_state["RealESRGAN_model"] = "RealESRGAN_x4plus"
|
||||
|
||||
variant_amount = st.slider("Variant Amount:", value=st.session_state['defaults'].img2img.variant_amount, min_value=0.0, max_value=1.0, step=0.01)
|
||||
variant_seed = st.text_input("Variant Seed:", value=st.session_state['defaults'].img2img.variant_seed,
|
||||
help="The seed to use when generating a variant, if left blank a random seed will be generated.")
|
||||
cfg_scale = st.slider("CFG (Classifier Free Guidance Scale):", min_value=1.0, max_value=30.0, value=st.session_state['defaults'].img2img.cfg_scale, step=0.5,
|
||||
help="How strongly the image should follow the prompt.")
|
||||
batch_size = st.slider("Batch size", min_value=1, max_value=100, value=st.session_state['defaults'].img2img.batch_size, step=1,
|
||||
help="How many images are at once in a batch.\
|
||||
It increases the VRAM usage a lot but if you have enough VRAM it can reduce the time it takes to finish \
|
||||
generation as more images are generated at once.\
|
||||
Default: 1")
|
||||
|
||||
st.session_state["denoising_strength"] = st.slider("Denoising Strength:", value=st.session_state['defaults'].img2img.denoising_strength,
|
||||
min_value=0.01, max_value=1.0, step=0.01)
|
||||
|
||||
with st.expander("Preview Settings"):
|
||||
st.session_state["update_preview"] = st.checkbox("Update Image Preview", value=st.session_state['defaults'].img2img.update_preview,
|
||||
help="If enabled the image preview will be updated during the generation instead of at the end. \
|
||||
You can use the Update Preview \Frequency option bellow to customize how frequent it's updated. \
|
||||
By default this is enabled and the frequency is set to 1 step.")
|
||||
|
||||
st.session_state["update_preview_frequency"] = st.text_input("Update Image Preview Frequency", value=st.session_state['defaults'].img2img.update_preview_frequency,
|
||||
help="Frequency in steps at which the the preview image is updated. By default the frequency \
|
||||
is set to 1 step.")
|
||||
|
||||
with col2_img2img_layout:
|
||||
editor_tab = st.tabs(["Editor"])
|
||||
|
||||
editor_image = st.empty()
|
||||
st.session_state["editor_image"] = editor_image
|
||||
|
||||
st.form_submit_button("Refresh")
|
||||
|
||||
masked_image_holder = st.empty()
|
||||
image_holder = st.empty()
|
||||
|
||||
uploaded_images = st.file_uploader(
|
||||
"Upload Image", accept_multiple_files=False, type=["png", "jpg", "jpeg", "webp"],
|
||||
help="Upload an image which will be used for the image to image generation.",
|
||||
)
|
||||
if uploaded_images:
|
||||
image = Image.open(uploaded_images).convert('RGBA')
|
||||
new_img = image.resize((width, height))
|
||||
image_holder.image(new_img)
|
||||
|
||||
mask_holder = st.empty()
|
||||
|
||||
uploaded_masks = st.file_uploader(
|
||||
"Upload Mask", accept_multiple_files=False, type=["png", "jpg", "jpeg", "webp"],
|
||||
help="Upload an mask image which will be used for masking the image to image generation.",
|
||||
)
|
||||
if uploaded_masks:
|
||||
mask = Image.open(uploaded_masks)
|
||||
if mask.mode == "RGBA":
|
||||
mask = mask.convert('RGBA')
|
||||
background = Image.new('RGBA', mask.size, (0, 0, 0))
|
||||
mask = Image.alpha_composite(background, mask)
|
||||
mask = mask.resize((width, height))
|
||||
mask_holder.image(mask)
|
||||
|
||||
if uploaded_images and uploaded_masks:
|
||||
if mask_mode != 2:
|
||||
final_img = new_img.copy()
|
||||
alpha_layer = mask.convert('L')
|
||||
strength = st.session_state["denoising_strength"]
|
||||
if mask_mode == 0:
|
||||
alpha_layer = ImageOps.invert(alpha_layer)
|
||||
alpha_layer = alpha_layer.point(lambda a: a * strength)
|
||||
alpha_layer = ImageOps.invert(alpha_layer)
|
||||
elif mask_mode == 1:
|
||||
alpha_layer = alpha_layer.point(lambda a: a * strength)
|
||||
alpha_layer = ImageOps.invert(alpha_layer)
|
||||
|
||||
final_img.putalpha(alpha_layer)
|
||||
|
||||
with masked_image_holder.container():
|
||||
st.text("Masked Image Preview")
|
||||
st.image(final_img)
|
||||
|
||||
|
||||
with col3_img2img_layout:
|
||||
result_tab = st.tabs(["Result"])
|
||||
|
||||
# create an empty container for the image, progress bar, etc so we can update it later and use session_state to hold them globally.
|
||||
preview_image = st.empty()
|
||||
st.session_state["preview_image"] = preview_image
|
||||
|
||||
#st.session_state["loading"] = st.empty()
|
||||
|
||||
st.session_state["progress_bar_text"] = st.empty()
|
||||
st.session_state["progress_bar"] = st.empty()
|
||||
|
||||
|
||||
message = st.empty()
|
||||
|
||||
#if uploaded_images:
|
||||
#image = Image.open(uploaded_images).convert('RGB')
|
||||
##img_array = np.array(image) # if you want to pass it to OpenCV
|
||||
#new_img = image.resize((width, height))
|
||||
#st.image(new_img, use_column_width=True)
|
||||
|
||||
|
||||
if generate_button:
|
||||
#print("Loading models")
|
||||
# load the models when we hit the generate button for the first time, it wont be loaded after that so dont worry.
|
||||
load_models(False, use_GFPGAN, st.session_state["use_RealESRGAN"], st.session_state["RealESRGAN_model"], st.session_state["CustomModel_available"],
|
||||
st.session_state["custom_model"])
|
||||
|
||||
if uploaded_images:
|
||||
image = Image.open(uploaded_images).convert('RGBA')
|
||||
new_img = image.resize((width, height))
|
||||
#img_array = np.array(image) # if you want to pass it to OpenCV
|
||||
new_mask = None
|
||||
if uploaded_masks:
|
||||
mask = Image.open(uploaded_masks).convert('RGBA')
|
||||
new_mask = mask.resize((width, height))
|
||||
|
||||
try:
|
||||
output_images, seed, info, stats = img2img(prompt=prompt, init_info=new_img, init_info_mask=new_mask, mask_mode=mask_mode,
|
||||
mask_restore=img2img_mask_restore, ddim_steps=st.session_state["sampling_steps"],
|
||||
sampler_name=st.session_state["sampler_name"], n_iter=batch_count,
|
||||
cfg_scale=cfg_scale, denoising_strength=st.session_state["denoising_strength"], variant_seed=variant_seed,
|
||||
seed=seed, noise_mode=noise_mode, find_noise_steps=find_noise_steps, width=width,
|
||||
height=height, variant_amount=variant_amount,
|
||||
ddim_eta=0.0, write_info_files=write_info_files, RealESRGAN_model=st.session_state["RealESRGAN_model"],
|
||||
separate_prompts=separate_prompts, normalize_prompt_weights=normalize_prompt_weights,
|
||||
save_individual_images=save_individual_images, save_grid=save_grid,
|
||||
group_by_prompt=group_by_prompt, save_as_jpg=save_as_jpg, use_GFPGAN=use_GFPGAN,
|
||||
use_RealESRGAN=st.session_state["use_RealESRGAN"] if not loopback else False, loopback=loopback
|
||||
)
|
||||
|
||||
#show a message when the generation is complete.
|
||||
message.success('Render Complete: ' + info + '; Stats: ' + stats, icon="✅")
|
||||
|
||||
except (StopException, KeyError):
|
||||
print(f"Received Streamlit StopException")
|
||||
|
||||
# this will render all the images at the end of the generation but its better if its moved to a second tab inside col2 and shown as a gallery.
|
||||
# use the current col2 first tab to show the preview_img and update it as its generated.
|
||||
#preview_image.image(output_images, width=750)
|
||||
|
||||
#on import run init
|
161
scripts/imglab.py
Normal file
161
scripts/imglab.py
Normal file
@ -0,0 +1,161 @@
|
||||
# base webui import and utils.
|
||||
from webui_streamlit import st
|
||||
from sd_utils import *
|
||||
|
||||
#home plugin
|
||||
import os
|
||||
from PIL import Image
|
||||
#from bs4 import BeautifulSoup
|
||||
from streamlit.runtime.in_memory_file_manager import in_memory_file_manager
|
||||
from streamlit.elements import image as STImage
|
||||
|
||||
# Temp imports
|
||||
|
||||
|
||||
# end of imports
|
||||
#---------------------------------------------------------------------------------------------------------------
|
||||
|
||||
try:
|
||||
# this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start.
|
||||
from transformers import logging
|
||||
|
||||
logging.set_verbosity_error()
|
||||
except:
|
||||
pass
|
||||
|
||||
class plugin_info():
|
||||
plugname = "imglab"
|
||||
description = "Image Lab"
|
||||
isTab = True
|
||||
displayPriority = 3
|
||||
|
||||
def getLatestGeneratedImagesFromPath():
|
||||
#get the latest images from the generated images folder
|
||||
#get the path to the generated images folder
|
||||
generatedImagesPath = os.path.join(os.getcwd(),'outputs')
|
||||
#get all the files from the folders and subfolders
|
||||
files = []
|
||||
#get the laest 10 images from the output folder without walking the subfolders
|
||||
for r, d, f in os.walk(generatedImagesPath):
|
||||
for file in f:
|
||||
if '.png' in file:
|
||||
files.append(os.path.join(r, file))
|
||||
#sort the files by date
|
||||
files.sort(key=os.path.getmtime)
|
||||
#reverse the list so the latest images are first
|
||||
for f in files:
|
||||
img = Image.open(f)
|
||||
files[files.index(f)] = img
|
||||
#get the latest 10 files
|
||||
#get all the files with the .png or .jpg extension
|
||||
#sort files by date
|
||||
#get the latest 10 files
|
||||
latestFiles = files[-10:]
|
||||
#reverse the list
|
||||
latestFiles.reverse()
|
||||
return latestFiles
|
||||
|
||||
def getImagesFromLexica():
|
||||
#scrape images from lexica.art
|
||||
#get the html from the page
|
||||
#get the html with cookies and javascript
|
||||
apiEndpoint = r'https://lexica.art/api/trpc/prompts.infinitePrompts?batch=1&input=%7B%220%22%3A%7B%22json%22%3A%7B%22limit%22%3A10%2C%22text%22%3A%22%22%2C%22cursor%22%3A10%7D%7D%7D'
|
||||
#REST API call
|
||||
#
|
||||
from requests_html import HTMLSession
|
||||
session = HTMLSession()
|
||||
|
||||
response = session.get(apiEndpoint)
|
||||
#req = requests.Session()
|
||||
#req.headers['user-agent'] = 'Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/96.0.4664.45 Safari/537.36'
|
||||
#response = req.get(apiEndpoint)
|
||||
print(response.status_code)
|
||||
print(response.text)
|
||||
#get the json from the response
|
||||
#json = response.json()
|
||||
#get the prompts from the json
|
||||
print(response)
|
||||
#session = requests.Session()
|
||||
#parseEndpointJson = session.get(apiEndpoint,headers=headers,verify=False)
|
||||
#print(parseEndpointJson)
|
||||
#print('test2')
|
||||
#page = requests.get("https://lexica.art/", headers={'User-Agent': 'Mozilla/5.0'})
|
||||
#parse the html
|
||||
#soup = BeautifulSoup(page.content, 'html.parser')
|
||||
#find all the images
|
||||
#print(soup)
|
||||
#images = soup.find_all('alt-image')
|
||||
#create a list to store the image urls
|
||||
image_urls = []
|
||||
#loop through the images
|
||||
for image in images:
|
||||
#get the url
|
||||
image_url = image['src']
|
||||
#add it to the list
|
||||
image_urls.append('http://www.lexica.art/'+image_url)
|
||||
#return the list
|
||||
print(image_urls)
|
||||
return image_urls
|
||||
def changeImage():
|
||||
#change the image in the image holder
|
||||
#check if the file is not empty
|
||||
if len(st.session_state['uploaded_file']) > 0:
|
||||
#read the file
|
||||
print('test2')
|
||||
uploaded = st.session_state['uploaded_file'][0].read()
|
||||
#show the image in the image holder
|
||||
st.session_state['previewImg'].empty()
|
||||
st.session_state['previewImg'].image(uploaded,use_column_width=True)
|
||||
def createHTMLGallery(images):
|
||||
html3 = """
|
||||
<div class="gallery-history" style="
|
||||
display: flex;
|
||||
flex-wrap: wrap;
|
||||
align-items: flex-start;">
|
||||
"""
|
||||
mkdwn_array = []
|
||||
for i in images:
|
||||
bImg = i.read()
|
||||
i = Image.save(bImg, 'PNG')
|
||||
width, height = i.size
|
||||
#get random number for the id
|
||||
image_id = "%s" % (str(images.index(i)))
|
||||
(data, mimetype) = STImage._normalize_to_bytes(bImg.getvalue(), width, 'auto')
|
||||
this_file = in_memory_file_manager.add(data, mimetype, image_id)
|
||||
img_str = this_file.url
|
||||
#img_str = 'data:image/png;base64,' + b64encode(image_io.getvalue()).decode('ascii')
|
||||
#get image size
|
||||
|
||||
#make sure the image is not bigger then 150px but keep the aspect ratio
|
||||
if width > 150:
|
||||
height = int(height * (150/width))
|
||||
width = 150
|
||||
if height > 150:
|
||||
width = int(width * (150/height))
|
||||
height = 150
|
||||
|
||||
#mkdwn = f"""<img src="{img_str}" alt="Image" with="200" height="200" />"""
|
||||
mkdwn = f'''<div class="gallery" style="margin: 3px;" >
|
||||
<a href="{img_str}">
|
||||
<img src="{img_str}" alt="Image" width="{width}" height="{height}">
|
||||
</a>
|
||||
</div>
|
||||
'''
|
||||
mkdwn_array.append(mkdwn)
|
||||
html3 += "".join(mkdwn_array)
|
||||
html3 += '</div>'
|
||||
return html3
|
||||
def layout():
|
||||
|
||||
col1, col2 = st.columns(2)
|
||||
with col1:
|
||||
st.session_state['uploaded_file'] = st.file_uploader("Choose an image or images", type=["png", "jpg", "jpeg", "webp"],accept_multiple_files=True,on_change=changeImage)
|
||||
if 'previewImg' not in st.session_state:
|
||||
st.session_state['previewImg'] = st.empty()
|
||||
else:
|
||||
if len(st.session_state['uploaded_file']) > 0:
|
||||
st.session_state['previewImg'].empty()
|
||||
st.session_state['previewImg'].image(st.session_state['uploaded_file'][0],use_column_width=True)
|
||||
else:
|
||||
st.session_state['previewImg'] = st.empty()
|
||||
|
48
scripts/perlin.py
Normal file
48
scripts/perlin.py
Normal file
@ -0,0 +1,48 @@
|
||||
import numpy as np
|
||||
|
||||
def perlin(x, y, seed=0):
|
||||
# permutation table
|
||||
np.random.seed(seed)
|
||||
p = np.arange(256, dtype=int)
|
||||
np.random.shuffle(p)
|
||||
p = np.stack([p, p]).flatten()
|
||||
# coordinates of the top-left
|
||||
xi, yi = x.astype(int), y.astype(int)
|
||||
# internal coordinates
|
||||
xf, yf = x - xi, y - yi
|
||||
# fade factors
|
||||
u, v = fade(xf), fade(yf)
|
||||
# noise components
|
||||
n00 = gradient(p[p[xi] + yi], xf, yf)
|
||||
n01 = gradient(p[p[xi] + yi + 1], xf, yf - 1)
|
||||
n11 = gradient(p[p[xi + 1] + yi + 1], xf - 1, yf - 1)
|
||||
n10 = gradient(p[p[xi + 1] + yi], xf - 1, yf)
|
||||
# combine noises
|
||||
x1 = lerp(n00, n10, u)
|
||||
x2 = lerp(n01, n11, u) # FIX1: I was using n10 instead of n01
|
||||
return lerp(x1, x2, v) # FIX2: I also had to reverse x1 and x2 here
|
||||
|
||||
def lerp(a, b, x):
|
||||
"linear interpolation"
|
||||
return a + x * (b - a)
|
||||
|
||||
def fade(t):
|
||||
"6t^5 - 15t^4 + 10t^3"
|
||||
return 6 * t**5 - 15 * t**4 + 10 * t**3
|
||||
|
||||
def gradient(h, x, y):
|
||||
"grad converts h to the right gradient vector and return the dot product with (x,y)"
|
||||
vectors = np.array([[0, 1], [0, -1], [1, 0], [-1, 0]])
|
||||
g = vectors[h % 4]
|
||||
return g[:, :, 0] * x + g[:, :, 1] * y
|
||||
|
||||
lin = np.linspace(0, 5, 100, endpoint=False)
|
||||
x, y = np.meshgrid(lin, lin)
|
||||
|
||||
|
||||
|
||||
def perlinNoise(height,width,octavesx=5,octavesy=5,seed=None):
|
||||
linx = np.linspace(0,octavesx,width,endpoint=False)
|
||||
liny = np.linspace(0,octavesy,height,endpoint=False)
|
||||
x,y = np.meshgrid(linx,liny)
|
||||
return perlin(x,y,seed=seed)
|
@ -19,6 +19,8 @@ optimized_turbo = False
|
||||
# Creates a public xxxxx.gradio.app share link to allow others to use your interface (requires properly forwarded ports to work correctly)
|
||||
share = False
|
||||
|
||||
# Generate tiling images
|
||||
tiling = False
|
||||
|
||||
# Enter other `--arguments` you wish to use - Must be entered as a `--argument ` syntax
|
||||
additional_arguments = ""
|
||||
@ -37,6 +39,8 @@ if optimized_turbo == True:
|
||||
common_arguments += "--optimized-turbo "
|
||||
if optimized == True:
|
||||
common_arguments += "--optimized "
|
||||
if tiling == True:
|
||||
common_arguments += "--tiling "
|
||||
if share == True:
|
||||
common_arguments += "--share "
|
||||
|
||||
|
1728
scripts/sd_utils.py
Normal file
1728
scripts/sd_utils.py
Normal file
File diff suppressed because it is too large
Load Diff
233
scripts/stable_diffusion_pipeline.py
Normal file
233
scripts/stable_diffusion_pipeline.py
Normal file
@ -0,0 +1,233 @@
|
||||
import inspect
|
||||
import warnings
|
||||
from tqdm.auto import tqdm
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import torch
|
||||
from diffusers import ModelMixin
|
||||
from diffusers.models import AutoencoderKL, UNet2DConditionModel
|
||||
from diffusers.pipeline_utils import DiffusionPipeline
|
||||
from diffusers.pipelines.stable_diffusion.safety_checker import \
|
||||
StableDiffusionSafetyChecker
|
||||
from diffusers.schedulers import (DDIMScheduler, LMSDiscreteScheduler,
|
||||
PNDMScheduler)
|
||||
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
||||
|
||||
|
||||
class StableDiffusionPipeline(DiffusionPipeline):
|
||||
def __init__(
|
||||
self,
|
||||
vae: AutoencoderKL,
|
||||
text_encoder: CLIPTextModel,
|
||||
tokenizer: CLIPTokenizer,
|
||||
unet: UNet2DConditionModel,
|
||||
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
|
||||
safety_checker: StableDiffusionSafetyChecker,
|
||||
feature_extractor: CLIPFeatureExtractor,
|
||||
):
|
||||
super().__init__()
|
||||
scheduler = scheduler.set_format("pt")
|
||||
self.register_modules(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
safety_checker=safety_checker,
|
||||
feature_extractor=feature_extractor,
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Optional[Union[str, List[str]]] = None,
|
||||
height: Optional[int] = 512,
|
||||
width: Optional[int] = 512,
|
||||
num_inference_steps: Optional[int] = 50,
|
||||
guidance_scale: Optional[float] = 7.5,
|
||||
eta: Optional[float] = 0.0,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
text_embeddings: Optional[torch.FloatTensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
**kwargs,
|
||||
):
|
||||
if "torch_device" in kwargs:
|
||||
device = kwargs.pop("torch_device")
|
||||
warnings.warn(
|
||||
"`torch_device` is deprecated as an input argument to `__call__` and"
|
||||
" will be removed in v0.3.0. Consider using `pipe.to(torch_device)`"
|
||||
" instead."
|
||||
)
|
||||
|
||||
# Set device as before (to be removed in 0.3.0)
|
||||
if device is None:
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
self.to(device)
|
||||
|
||||
if text_embeddings is None:
|
||||
if isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"`prompt` has to be of type `str` or `list` but is {type(prompt)}"
|
||||
)
|
||||
|
||||
if height % 8 != 0 or width % 8 != 0:
|
||||
raise ValueError(
|
||||
"`height` and `width` have to be divisible by 8 but are"
|
||||
f" {height} and {width}."
|
||||
)
|
||||
|
||||
# get prompt text embeddings
|
||||
text_input = self.tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=self.tokenizer.model_max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]
|
||||
else:
|
||||
batch_size = text_embeddings.shape[0]
|
||||
|
||||
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
||||
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
||||
# corresponds to doing no classifier free guidance.
|
||||
do_classifier_free_guidance = guidance_scale > 1.0
|
||||
# get unconditional embeddings for classifier free guidance
|
||||
if do_classifier_free_guidance:
|
||||
# max_length = text_input.input_ids.shape[-1]
|
||||
max_length = 77 # self.tokenizer.model_max_length
|
||||
uncond_input = self.tokenizer(
|
||||
[""] * batch_size,
|
||||
padding="max_length",
|
||||
max_length=max_length,
|
||||
return_tensors="pt",
|
||||
)
|
||||
uncond_embeddings = self.text_encoder(
|
||||
uncond_input.input_ids.to(self.device)
|
||||
)[0]
|
||||
|
||||
# For classifier free guidance, we need to do two forward passes.
|
||||
# Here we concatenate the unconditional and text embeddings into a single batch
|
||||
# to avoid doing two forward passes
|
||||
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
|
||||
|
||||
# get the initial random noise unless the user supplied it
|
||||
latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8)
|
||||
if latents is None:
|
||||
latents = torch.randn(
|
||||
latents_shape,
|
||||
generator=generator,
|
||||
device=self.device,
|
||||
)
|
||||
else:
|
||||
if latents.shape != latents_shape:
|
||||
raise ValueError(
|
||||
f"Unexpected latents shape, got {latents.shape}, expected"
|
||||
f" {latents_shape}"
|
||||
)
|
||||
latents = latents.to(self.device)
|
||||
|
||||
# set timesteps
|
||||
accepts_offset = "offset" in set(
|
||||
inspect.signature(self.scheduler.set_timesteps).parameters.keys()
|
||||
)
|
||||
extra_set_kwargs = {}
|
||||
if accepts_offset:
|
||||
extra_set_kwargs["offset"] = 1
|
||||
|
||||
self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)
|
||||
|
||||
# if we use LMSDiscreteScheduler, let's make sure latents are mulitplied by sigmas
|
||||
if isinstance(self.scheduler, LMSDiscreteScheduler):
|
||||
latents = latents * self.scheduler.sigmas[0]
|
||||
|
||||
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
||||
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
||||
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
||||
# and should be between [0, 1]
|
||||
accepts_eta = "eta" in set(
|
||||
inspect.signature(self.scheduler.step).parameters.keys()
|
||||
)
|
||||
extra_step_kwargs = {}
|
||||
if accepts_eta:
|
||||
extra_step_kwargs["eta"] = eta
|
||||
|
||||
for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = (
|
||||
torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||
)
|
||||
if isinstance(self.scheduler, LMSDiscreteScheduler):
|
||||
sigma = self.scheduler.sigmas[i]
|
||||
# the model input needs to be scaled to match the continuous ODE formulation in K-LMS
|
||||
latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)
|
||||
|
||||
# predict the noise residual
|
||||
noise_pred = self.unet(
|
||||
latent_model_input, t, encoder_hidden_states=text_embeddings
|
||||
)["sample"]
|
||||
|
||||
# perform guidance
|
||||
if do_classifier_free_guidance:
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + guidance_scale * (
|
||||
noise_pred_text - noise_pred_uncond
|
||||
)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
if isinstance(self.scheduler, LMSDiscreteScheduler):
|
||||
latents = self.scheduler.step(
|
||||
noise_pred, i, latents, **extra_step_kwargs
|
||||
)["prev_sample"]
|
||||
else:
|
||||
latents = self.scheduler.step(
|
||||
noise_pred, t, latents, **extra_step_kwargs
|
||||
)["prev_sample"]
|
||||
|
||||
# scale and decode the image latents with vae
|
||||
latents = 1 / 0.18215 * latents
|
||||
image = self.vae.decode(latents).sample
|
||||
|
||||
image = (image / 2 + 0.5).clamp(0, 1)
|
||||
image = image.cpu().permute(0, 2, 3, 1).numpy()
|
||||
|
||||
safety_cheker_input = self.feature_extractor(
|
||||
self.numpy_to_pil(image), return_tensors="pt"
|
||||
).to(self.device)
|
||||
image, has_nsfw_concept = self.safety_checker(
|
||||
images=image, clip_input=safety_cheker_input.pixel_values
|
||||
)
|
||||
|
||||
if output_type == "pil":
|
||||
image = self.numpy_to_pil(image)
|
||||
|
||||
return {"sample": image, "nsfw_content_detected": has_nsfw_concept}
|
||||
|
||||
def embed_text(self, text):
|
||||
"""Helper to embed some text"""
|
||||
with torch.autocast("cuda"):
|
||||
text_input = self.tokenizer(
|
||||
text,
|
||||
padding="max_length",
|
||||
max_length=self.tokenizer.model_max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
with torch.no_grad():
|
||||
embed = self.text_encoder(text_input.input_ids.to(self.device))[0]
|
||||
return embed
|
||||
|
||||
|
||||
class NoCheck(ModelMixin):
|
||||
"""Can be used in place of safety checker. Use responsibly and at your own risk."""
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.register_parameter(name='asdf', param=torch.nn.Parameter(torch.randn(3)))
|
||||
|
||||
def forward(self, images=None, **kwargs):
|
||||
return images, [False]
|
218
scripts/stable_diffusion_walk.py
Normal file
218
scripts/stable_diffusion_walk.py
Normal file
@ -0,0 +1,218 @@
|
||||
import json
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from diffusers.schedulers import (DDIMScheduler, LMSDiscreteScheduler,
|
||||
PNDMScheduler)
|
||||
from diffusers import ModelMixin
|
||||
|
||||
from stable_diffusion_pipeline import StableDiffusionPipeline
|
||||
|
||||
pipeline = StableDiffusionPipeline.from_pretrained(
|
||||
"CompVis/stable-diffusion-v1-4",
|
||||
use_auth_token=True,
|
||||
torch_dtype=torch.float16,
|
||||
revision="fp16",
|
||||
).to("cuda")
|
||||
|
||||
default_scheduler = PNDMScheduler(
|
||||
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
|
||||
)
|
||||
ddim_scheduler = DDIMScheduler(
|
||||
beta_start=0.00085,
|
||||
beta_end=0.012,
|
||||
beta_schedule="scaled_linear",
|
||||
clip_sample=False,
|
||||
set_alpha_to_one=False,
|
||||
)
|
||||
klms_scheduler = LMSDiscreteScheduler(
|
||||
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
|
||||
)
|
||||
SCHEDULERS = dict(default=default_scheduler, ddim=ddim_scheduler, klms=klms_scheduler)
|
||||
|
||||
|
||||
def slerp(t, v0, v1, DOT_THRESHOLD=0.9995):
|
||||
"""helper function to spherically interpolate two arrays v1 v2"""
|
||||
|
||||
if not isinstance(v0, np.ndarray):
|
||||
inputs_are_torch = True
|
||||
input_device = v0.device
|
||||
v0 = v0.cpu().numpy()
|
||||
v1 = v1.cpu().numpy()
|
||||
|
||||
dot = np.sum(v0 * v1 / (np.linalg.norm(v0) * np.linalg.norm(v1)))
|
||||
if np.abs(dot) > DOT_THRESHOLD:
|
||||
v2 = (1 - t) * v0 + t * v1
|
||||
else:
|
||||
theta_0 = np.arccos(dot)
|
||||
sin_theta_0 = np.sin(theta_0)
|
||||
theta_t = theta_0 * t
|
||||
sin_theta_t = np.sin(theta_t)
|
||||
s0 = np.sin(theta_0 - theta_t) / sin_theta_0
|
||||
s1 = sin_theta_t / sin_theta_0
|
||||
v2 = s0 * v0 + s1 * v1
|
||||
|
||||
if inputs_are_torch:
|
||||
v2 = torch.from_numpy(v2).to(input_device)
|
||||
|
||||
return v2
|
||||
|
||||
|
||||
def make_video_ffmpeg(frame_dir, output_file_name='output.mp4', frame_filename="frame%06d.jpg", fps=30):
|
||||
frame_ref_path = str(frame_dir / frame_filename)
|
||||
video_path = str(frame_dir / output_file_name)
|
||||
subprocess.call(
|
||||
f"ffmpeg -r {fps} -i {frame_ref_path} -vcodec libx264 -crf 10 -pix_fmt yuv420p"
|
||||
f" {video_path}".split()
|
||||
)
|
||||
return video_path
|
||||
|
||||
|
||||
def walk(
|
||||
prompts=["blueberry spaghetti", "strawberry spaghetti"],
|
||||
seeds=[42, 123],
|
||||
num_steps=5,
|
||||
output_dir="dreams",
|
||||
name="berry_good_spaghetti",
|
||||
height=512,
|
||||
width=512,
|
||||
guidance_scale=7.5,
|
||||
eta=0.0,
|
||||
num_inference_steps=50,
|
||||
do_loop=False,
|
||||
make_video=False,
|
||||
use_lerp_for_text=False,
|
||||
scheduler="klms", # choices: default, ddim, klms
|
||||
disable_tqdm=False,
|
||||
upsample=False,
|
||||
fps=30,
|
||||
):
|
||||
"""Generate video frames/a video given a list of prompts and seeds.
|
||||
|
||||
Args:
|
||||
prompts (List[str], optional): List of . Defaults to ["blueberry spaghetti", "strawberry spaghetti"].
|
||||
seeds (List[int], optional): List of random seeds corresponding to given prompts.
|
||||
num_steps (int, optional): Number of steps to walk. Increase this value to 60-200 for good results. Defaults to 5.
|
||||
output_dir (str, optional): Root dir where images will be saved. Defaults to "dreams".
|
||||
name (str, optional): Sub directory of output_dir to save this run's files. Defaults to "berry_good_spaghetti".
|
||||
height (int, optional): Height of image to generate. Defaults to 512.
|
||||
width (int, optional): Width of image to generate. Defaults to 512.
|
||||
guidance_scale (float, optional): Higher = more adherance to prompt. Lower = let model take the wheel. Defaults to 7.5.
|
||||
eta (float, optional): ETA. Defaults to 0.0.
|
||||
num_inference_steps (int, optional): Number of diffusion steps. Defaults to 50.
|
||||
do_loop (bool, optional): Whether to loop from last prompt back to first. Defaults to False.
|
||||
make_video (bool, optional): Whether to make a video or just save the images. Defaults to False.
|
||||
use_lerp_for_text (bool, optional): Use LERP instead of SLERP for text embeddings when walking. Defaults to False.
|
||||
scheduler (str, optional): Which scheduler to use. Defaults to "klms". Choices are "default", "ddim", "klms".
|
||||
disable_tqdm (bool, optional): Whether to turn off the tqdm progress bars. Defaults to False.
|
||||
upsample (bool, optional): If True, uses Real-ESRGAN to upsample images 4x. Requires it to be installed
|
||||
which you can do by running: `pip install git+https://github.com/xinntao/Real-ESRGAN.git`. Defaults to False.
|
||||
fps (int, optional): The frames per second (fps) that you want the video to use. Does nothing if make_video is False. Defaults to 30.
|
||||
|
||||
Returns:
|
||||
str: Path to video file saved if make_video=True, else None.
|
||||
"""
|
||||
if upsample:
|
||||
from .upsampling import PipelineRealESRGAN
|
||||
|
||||
upsampling_pipeline = PipelineRealESRGAN.from_pretrained('nateraw/real-esrgan')
|
||||
|
||||
pipeline.set_progress_bar_config(disable=disable_tqdm)
|
||||
|
||||
pipeline.scheduler = SCHEDULERS[scheduler]
|
||||
|
||||
output_path = Path(output_dir) / name
|
||||
output_path.mkdir(exist_ok=True, parents=True)
|
||||
|
||||
# Write prompt info to file in output dir so we can keep track of what we did
|
||||
prompt_config_path = output_path / 'prompt_config.json'
|
||||
prompt_config_path.write_text(
|
||||
json.dumps(
|
||||
dict(
|
||||
prompts=prompts,
|
||||
seeds=seeds,
|
||||
num_steps=num_steps,
|
||||
name=name,
|
||||
guidance_scale=guidance_scale,
|
||||
eta=eta,
|
||||
num_inference_steps=num_inference_steps,
|
||||
do_loop=do_loop,
|
||||
make_video=make_video,
|
||||
use_lerp_for_text=use_lerp_for_text,
|
||||
scheduler=scheduler
|
||||
),
|
||||
indent=2,
|
||||
sort_keys=False,
|
||||
)
|
||||
)
|
||||
|
||||
assert len(prompts) == len(seeds)
|
||||
|
||||
first_prompt, *prompts = prompts
|
||||
embeds_a = pipeline.embed_text(first_prompt)
|
||||
|
||||
first_seed, *seeds = seeds
|
||||
latents_a = torch.randn(
|
||||
(1, pipeline.unet.in_channels, height // 8, width // 8),
|
||||
device=pipeline.device,
|
||||
generator=torch.Generator(device=pipeline.device).manual_seed(first_seed),
|
||||
)
|
||||
|
||||
if do_loop:
|
||||
prompts.append(first_prompt)
|
||||
seeds.append(first_seed)
|
||||
|
||||
frame_index = 0
|
||||
for prompt, seed in zip(prompts, seeds):
|
||||
# Text
|
||||
embeds_b = pipeline.embed_text(prompt)
|
||||
|
||||
# Latent Noise
|
||||
latents_b = torch.randn(
|
||||
(1, pipeline.unet.in_channels, height // 8, width // 8),
|
||||
device=pipeline.device,
|
||||
generator=torch.Generator(device=pipeline.device).manual_seed(seed),
|
||||
)
|
||||
|
||||
for i, t in enumerate(np.linspace(0, 1, num_steps)):
|
||||
do_print_progress = (i == 0) or ((frame_index + 1) % 20 == 0)
|
||||
if do_print_progress:
|
||||
print(f"COUNT: {frame_index+1}/{len(seeds)*num_steps}")
|
||||
|
||||
if use_lerp_for_text:
|
||||
embeds = torch.lerp(embeds_a, embeds_b, float(t))
|
||||
else:
|
||||
embeds = slerp(float(t), embeds_a, embeds_b)
|
||||
latents = slerp(float(t), latents_a, latents_b)
|
||||
|
||||
with torch.autocast("cuda"):
|
||||
im = pipeline(
|
||||
latents=latents,
|
||||
text_embeddings=embeds,
|
||||
height=height,
|
||||
width=width,
|
||||
guidance_scale=guidance_scale,
|
||||
eta=eta,
|
||||
num_inference_steps=num_inference_steps,
|
||||
output_type='pil' if not upsample else 'numpy'
|
||||
)["sample"][0]
|
||||
|
||||
if upsample:
|
||||
im = upsampling_pipeline(im)
|
||||
|
||||
im.save(output_path / ("frame%06d.jpg" % frame_index))
|
||||
frame_index += 1
|
||||
|
||||
embeds_a = embeds_b
|
||||
latents_a = latents_b
|
||||
|
||||
if make_video:
|
||||
return make_video_ffmpeg(output_path, f"{name}.mp4", fps=fps)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import fire
|
||||
|
||||
fire.Fire(walk)
|
57
scripts/textual_inversion.py
Normal file
57
scripts/textual_inversion.py
Normal file
@ -0,0 +1,57 @@
|
||||
# base webui import and utils.
|
||||
from webui_streamlit import st
|
||||
from sd_utils import *
|
||||
|
||||
# streamlit imports
|
||||
|
||||
|
||||
#other imports
|
||||
#from transformers import CLIPTextModel, CLIPTokenizer
|
||||
|
||||
# Temp imports
|
||||
|
||||
|
||||
# end of imports
|
||||
#---------------------------------------------------------------------------------------------------------------
|
||||
|
||||
#def load_learned_embed_in_clip(learned_embeds_path, text_encoder, tokenizer, token=None):
|
||||
|
||||
#loaded_learned_embeds = torch.load(learned_embeds_path, map_location="cpu")
|
||||
|
||||
## separate token and the embeds
|
||||
#print (loaded_learned_embeds)
|
||||
#trained_token = list(loaded_learned_embeds.keys())[0]
|
||||
#embeds = loaded_learned_embeds[trained_token]
|
||||
|
||||
## cast to dtype of text_encoder
|
||||
#dtype = text_encoder.get_input_embeddings().weight.dtype
|
||||
#embeds.to(dtype)
|
||||
|
||||
## add the token in tokenizer
|
||||
#token = token if token is not None else trained_token
|
||||
#num_added_tokens = tokenizer.add_tokens(token)
|
||||
#i = 1
|
||||
#while(num_added_tokens == 0):
|
||||
#print(f"The tokenizer already contains the token {token}.")
|
||||
#token = f"{token[:-1]}-{i}>"
|
||||
#print(f"Attempting to add the token {token}.")
|
||||
#num_added_tokens = tokenizer.add_tokens(token)
|
||||
#i+=1
|
||||
|
||||
## resize the token embeddings
|
||||
#text_encoder.resize_token_embeddings(len(tokenizer))
|
||||
|
||||
## get the id for the token and assign the embeds
|
||||
#token_id = tokenizer.convert_tokens_to_ids(token)
|
||||
#text_encoder.get_input_embeddings().weight.data[token_id] = embeds
|
||||
#return token
|
||||
|
||||
##def token_loader()
|
||||
#learned_token = load_learned_embed_in_clip(f"models/custom/embeddings/Custom Ami.pt", st.session_state.pipe.text_encoder, st.session_state.pipe.tokenizer, "*")
|
||||
#model_content["token"] = learned_token
|
||||
#models.append(model_content)
|
||||
|
||||
model_id = "./models/custom/embeddings/"
|
||||
|
||||
def layout():
|
||||
st.write("Textual Inversion")
|
368
scripts/txt2img.py
Normal file
368
scripts/txt2img.py
Normal file
@ -0,0 +1,368 @@
|
||||
# base webui import and utils.
|
||||
from webui_streamlit import st
|
||||
from sd_utils import *
|
||||
|
||||
# streamlit imports
|
||||
from streamlit import StopException
|
||||
from streamlit.runtime.in_memory_file_manager import in_memory_file_manager
|
||||
from streamlit.elements import image as STImage
|
||||
|
||||
#other imports
|
||||
import os
|
||||
from typing import Union
|
||||
from io import BytesIO
|
||||
from ldm.models.diffusion.ddim import DDIMSampler
|
||||
from ldm.models.diffusion.plms import PLMSSampler
|
||||
|
||||
# Temp imports
|
||||
|
||||
|
||||
# end of imports
|
||||
#---------------------------------------------------------------------------------------------------------------
|
||||
|
||||
|
||||
try:
|
||||
# this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start.
|
||||
from transformers import logging
|
||||
|
||||
logging.set_verbosity_error()
|
||||
except:
|
||||
pass
|
||||
|
||||
class plugin_info():
|
||||
plugname = "txt2img"
|
||||
description = "Text to Image"
|
||||
isTab = True
|
||||
displayPriority = 1
|
||||
|
||||
|
||||
if os.path.exists(os.path.join(st.session_state['defaults'].general.GFPGAN_dir, "experiments", "pretrained_models", "GFPGANv1.3.pth")):
|
||||
GFPGAN_available = True
|
||||
else:
|
||||
GFPGAN_available = False
|
||||
|
||||
if os.path.exists(os.path.join(st.session_state['defaults'].general.RealESRGAN_dir, "experiments","pretrained_models", f"{st.session_state['defaults'].general.RealESRGAN_model}.pth")):
|
||||
RealESRGAN_available = True
|
||||
else:
|
||||
RealESRGAN_available = False
|
||||
|
||||
#
|
||||
def txt2img(prompt: str, ddim_steps: int, sampler_name: str, realesrgan_model_name: str,
|
||||
n_iter: int, batch_size: int, cfg_scale: float, seed: Union[int, str, None],
|
||||
height: int, width: int, separate_prompts:bool = False, normalize_prompt_weights:bool = True,
|
||||
save_individual_images: bool = True, save_grid: bool = True, group_by_prompt: bool = True,
|
||||
save_as_jpg: bool = True, use_GFPGAN: bool = True, use_RealESRGAN: bool = True,
|
||||
RealESRGAN_model: str = "RealESRGAN_x4plus_anime_6B", fp = None, variant_amount: float = None,
|
||||
variant_seed: int = None, ddim_eta:float = 0.0, write_info_files:bool = True):
|
||||
|
||||
outpath = st.session_state['defaults'].general.outdir_txt2img or st.session_state['defaults'].general.outdir or "outputs/txt2img-samples"
|
||||
|
||||
seed = seed_to_int(seed)
|
||||
|
||||
#prompt_matrix = 0 in toggles
|
||||
#normalize_prompt_weights = 1 in toggles
|
||||
#skip_save = 2 not in toggles
|
||||
#save_grid = 3 not in toggles
|
||||
#sort_samples = 4 in toggles
|
||||
#write_info_files = 5 in toggles
|
||||
#jpg_sample = 6 in toggles
|
||||
#use_GFPGAN = 7 in toggles
|
||||
#use_RealESRGAN = 8 in toggles
|
||||
|
||||
if sampler_name == 'PLMS':
|
||||
sampler = PLMSSampler(st.session_state["model"])
|
||||
elif sampler_name == 'DDIM':
|
||||
sampler = DDIMSampler(st.session_state["model"])
|
||||
elif sampler_name == 'k_dpm_2_a':
|
||||
sampler = KDiffusionSampler(st.session_state["model"],'dpm_2_ancestral')
|
||||
elif sampler_name == 'k_dpm_2':
|
||||
sampler = KDiffusionSampler(st.session_state["model"],'dpm_2')
|
||||
elif sampler_name == 'k_euler_a':
|
||||
sampler = KDiffusionSampler(st.session_state["model"],'euler_ancestral')
|
||||
elif sampler_name == 'k_euler':
|
||||
sampler = KDiffusionSampler(st.session_state["model"],'euler')
|
||||
elif sampler_name == 'k_heun':
|
||||
sampler = KDiffusionSampler(st.session_state["model"],'heun')
|
||||
elif sampler_name == 'k_lms':
|
||||
sampler = KDiffusionSampler(st.session_state["model"],'lms')
|
||||
else:
|
||||
raise Exception("Unknown sampler: " + sampler_name)
|
||||
|
||||
def init():
|
||||
pass
|
||||
|
||||
def sample(init_data, x, conditioning, unconditional_conditioning, sampler_name):
|
||||
samples_ddim, _ = sampler.sample(S=ddim_steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=cfg_scale,
|
||||
unconditional_conditioning=unconditional_conditioning, eta=ddim_eta, x_T=x, img_callback=generation_callback,
|
||||
log_every_t=int(st.session_state.update_preview_frequency))
|
||||
|
||||
return samples_ddim
|
||||
|
||||
#try:
|
||||
output_images, seed, info, stats = process_images(
|
||||
outpath=outpath,
|
||||
func_init=init,
|
||||
func_sample=sample,
|
||||
prompt=prompt,
|
||||
seed=seed,
|
||||
sampler_name=sampler_name,
|
||||
save_grid=save_grid,
|
||||
batch_size=batch_size,
|
||||
n_iter=n_iter,
|
||||
steps=ddim_steps,
|
||||
cfg_scale=cfg_scale,
|
||||
width=width,
|
||||
height=height,
|
||||
prompt_matrix=separate_prompts,
|
||||
use_GFPGAN=st.session_state["use_GFPGAN"],
|
||||
use_RealESRGAN=st.session_state["use_RealESRGAN"],
|
||||
realesrgan_model_name=realesrgan_model_name,
|
||||
ddim_eta=ddim_eta,
|
||||
normalize_prompt_weights=normalize_prompt_weights,
|
||||
save_individual_images=save_individual_images,
|
||||
sort_samples=group_by_prompt,
|
||||
write_info_files=write_info_files,
|
||||
jpg_sample=save_as_jpg,
|
||||
variant_amount=variant_amount,
|
||||
variant_seed=variant_seed,
|
||||
)
|
||||
|
||||
del sampler
|
||||
|
||||
return output_images, seed, info, stats
|
||||
|
||||
#except RuntimeError as e:
|
||||
#err = e
|
||||
#err_msg = f'CRASHED:<br><textarea rows="5" style="color:white;background: black;width: -webkit-fill-available;font-family: monospace;font-size: small;font-weight: bold;">{str(e)}</textarea><br><br>Please wait while the program restarts.'
|
||||
#stats = err_msg
|
||||
#return [], seed, 'err', stats
|
||||
|
||||
def layout():
|
||||
with st.form("txt2img-inputs"):
|
||||
st.session_state["generation_mode"] = "txt2img"
|
||||
|
||||
input_col1, generate_col1 = st.columns([10,1])
|
||||
|
||||
with input_col1:
|
||||
#prompt = st.text_area("Input Text","")
|
||||
prompt = st.text_input("Input Text","", placeholder="A corgi wearing a top hat as an oil painting.")
|
||||
|
||||
# Every form must have a submit button, the extra blank spaces is a temp way to align it with the input field. Needs to be done in CSS or some other way.
|
||||
generate_col1.write("")
|
||||
generate_col1.write("")
|
||||
generate_button = generate_col1.form_submit_button("Generate")
|
||||
|
||||
# creating the page layout using columns
|
||||
col1, col2, col3 = st.columns([1,2,1], gap="large")
|
||||
|
||||
with col1:
|
||||
width = st.slider("Width:", min_value=64, max_value=4096, value=st.session_state['defaults'].txt2img.width, step=64)
|
||||
height = st.slider("Height:", min_value=64, max_value=4096, value=st.session_state['defaults'].txt2img.height, step=64)
|
||||
cfg_scale = st.slider("CFG (Classifier Free Guidance Scale):", min_value=1.0, max_value=30.0, value=st.session_state['defaults'].txt2img.cfg_scale, step=0.5, help="How strongly the image should follow the prompt.")
|
||||
seed = st.text_input("Seed:", value=st.session_state['defaults'].txt2img.seed, help=" The seed to use, if left blank a random seed will be generated.")
|
||||
batch_count = st.slider("Batch count.", min_value=1, max_value=100, value=st.session_state['defaults'].txt2img.batch_count, step=1, help="How many iterations or batches of images to generate in total.")
|
||||
|
||||
bs_slider_max_value = 5
|
||||
if st.session_state.defaults.general.optimized:
|
||||
bs_slider_max_value = 100
|
||||
|
||||
batch_size = st.slider(
|
||||
"Batch size",
|
||||
min_value=1,
|
||||
max_value=bs_slider_max_value,
|
||||
value=st.session_state.defaults.txt2img.batch_size,
|
||||
step=1,
|
||||
help="How many images are at once in a batch.\
|
||||
It increases the VRAM usage a lot but if you have enough VRAM it can reduce the time it takes to finish generation as more images are generated at once.\
|
||||
Default: 1")
|
||||
|
||||
with st.expander("Preview Settings"):
|
||||
st.session_state["update_preview"] = st.checkbox("Update Image Preview", value=st.session_state['defaults'].txt2img.update_preview,
|
||||
help="If enabled the image preview will be updated during the generation instead of at the end. \
|
||||
You can use the Update Preview \Frequency option bellow to customize how frequent it's updated. \
|
||||
By default this is enabled and the frequency is set to 1 step.")
|
||||
|
||||
st.session_state["update_preview_frequency"] = st.text_input("Update Image Preview Frequency", value=st.session_state['defaults'].txt2img.update_preview_frequency,
|
||||
help="Frequency in steps at which the the preview image is updated. By default the frequency \
|
||||
is set to 1 step.")
|
||||
|
||||
with col2:
|
||||
preview_tab, gallery_tab = st.tabs(["Preview", "Gallery"])
|
||||
|
||||
with preview_tab:
|
||||
#st.write("Image")
|
||||
#Image for testing
|
||||
#image = Image.open(requests.get("https://icon-library.com/images/image-placeholder-icon/image-placeholder-icon-13.jpg", stream=True).raw).convert('RGB')
|
||||
#new_image = image.resize((175, 240))
|
||||
#preview_image = st.image(image)
|
||||
|
||||
# create an empty container for the image, progress bar, etc so we can update it later and use session_state to hold them globally.
|
||||
st.session_state["preview_image"] = st.empty()
|
||||
|
||||
st.session_state["loading"] = st.empty()
|
||||
|
||||
st.session_state["progress_bar_text"] = st.empty()
|
||||
st.session_state["progress_bar"] = st.empty()
|
||||
|
||||
message = st.empty()
|
||||
|
||||
with col3:
|
||||
# If we have custom models available on the "models/custom"
|
||||
#folder then we show a menu to select which model we want to use, otherwise we use the main model for SD
|
||||
if st.session_state.CustomModel_available:
|
||||
st.session_state.custom_model = st.selectbox("Custom Model:", st.session_state.custom_models,
|
||||
index=st.session_state["custom_models"].index(st.session_state['defaults'].general.default_model),
|
||||
help="Select the model you want to use. This option is only available if you have custom models \
|
||||
on your 'models/custom' folder. The model name that will be shown here is the same as the name\
|
||||
the file for the model has on said folder, it is recommended to give the .ckpt file a name that \
|
||||
will make it easier for you to distinguish it from other models. Default: Stable Diffusion v1.4")
|
||||
|
||||
st.session_state.sampling_steps = st.slider("Sampling Steps",
|
||||
value=st.session_state['defaults'].txt2img.sampling_steps,
|
||||
min_value=st.session_state['defaults'].txt2img.slider_bounds.sampling.lower,
|
||||
max_value=st.session_state['defaults'].txt2img.slider_bounds.sampling.upper,
|
||||
step=st.session_state['defaults'].txt2img.slider_steps.sampling)
|
||||
|
||||
sampler_name_list = ["k_lms", "k_euler", "k_euler_a", "k_dpm_2", "k_dpm_2_a", "k_heun", "PLMS", "DDIM"]
|
||||
sampler_name = st.selectbox("Sampling method", sampler_name_list,
|
||||
index=sampler_name_list.index(st.session_state['defaults'].txt2img.default_sampler), help="Sampling method to use. Default: k_euler")
|
||||
|
||||
|
||||
|
||||
#basic_tab, advanced_tab = st.tabs(["Basic", "Advanced"])
|
||||
|
||||
#with basic_tab:
|
||||
#summit_on_enter = st.radio("Submit on enter?", ("Yes", "No"), horizontal=True,
|
||||
#help="Press the Enter key to summit, when 'No' is selected you can use the Enter key to write multiple lines.")
|
||||
|
||||
with st.expander("Advanced"):
|
||||
separate_prompts = st.checkbox("Create Prompt Matrix.", value=st.session_state['defaults'].txt2img.separate_prompts, help="Separate multiple prompts using the `|` character, and get all combinations of them.")
|
||||
normalize_prompt_weights = st.checkbox("Normalize Prompt Weights.", value=st.session_state['defaults'].txt2img.normalize_prompt_weights, help="Ensure the sum of all weights add up to 1.0")
|
||||
save_individual_images = st.checkbox("Save individual images.", value=st.session_state['defaults'].txt2img.save_individual_images, help="Save each image generated before any filter or enhancement is applied.")
|
||||
save_grid = st.checkbox("Save grid",value=st.session_state['defaults'].txt2img.save_grid, help="Save a grid with all the images generated into a single image.")
|
||||
group_by_prompt = st.checkbox("Group results by prompt", value=st.session_state['defaults'].txt2img.group_by_prompt,
|
||||
help="Saves all the images with the same prompt into the same folder. When using a prompt matrix each prompt combination will have its own folder.")
|
||||
write_info_files = st.checkbox("Write Info file", value=st.session_state['defaults'].txt2img.write_info_files, help="Save a file next to the image with informartion about the generation.")
|
||||
save_as_jpg = st.checkbox("Save samples as jpg", value=st.session_state['defaults'].txt2img.save_as_jpg, help="Saves the images as jpg instead of png.")
|
||||
|
||||
if st.session_state["GFPGAN_available"]:
|
||||
st.session_state["use_GFPGAN"] = st.checkbox("Use GFPGAN", value=st.session_state['defaults'].txt2img.use_GFPGAN, help="Uses the GFPGAN model to improve faces after the generation.\
|
||||
This greatly improve the quality and consistency of faces but uses extra VRAM. Disable if you need the extra VRAM.")
|
||||
else:
|
||||
st.session_state["use_GFPGAN"] = False
|
||||
|
||||
if st.session_state["RealESRGAN_available"]:
|
||||
st.session_state["use_RealESRGAN"] = st.checkbox("Use RealESRGAN", value=st.session_state['defaults'].txt2img.use_RealESRGAN,
|
||||
help="Uses the RealESRGAN model to upscale the images after the generation.\
|
||||
This greatly improve the quality and lets you have high resolution images but uses extra VRAM. Disable if you need the extra VRAM.")
|
||||
st.session_state["RealESRGAN_model"] = st.selectbox("RealESRGAN model", ["RealESRGAN_x4plus", "RealESRGAN_x4plus_anime_6B"], index=0)
|
||||
else:
|
||||
st.session_state["use_RealESRGAN"] = False
|
||||
st.session_state["RealESRGAN_model"] = "RealESRGAN_x4plus"
|
||||
|
||||
variant_amount = st.slider("Variant Amount:", value=st.session_state['defaults'].txt2img.variant_amount, min_value=0.0, max_value=1.0, step=0.01)
|
||||
variant_seed = st.text_input("Variant Seed:", value=st.session_state['defaults'].txt2img.seed, help="The seed to use when generating a variant, if left blank a random seed will be generated.")
|
||||
galleryCont = st.empty()
|
||||
|
||||
if generate_button:
|
||||
#print("Loading models")
|
||||
# load the models when we hit the generate button for the first time, it wont be loaded after that so dont worry.
|
||||
load_models(False, st.session_state["use_GFPGAN"], st.session_state["use_RealESRGAN"], st.session_state["RealESRGAN_model"], st.session_state["CustomModel_available"],
|
||||
st.session_state["custom_model"])
|
||||
|
||||
|
||||
try:
|
||||
#
|
||||
output_images, seeds, info, stats = txt2img(prompt, st.session_state.sampling_steps, sampler_name, st.session_state["RealESRGAN_model"], batch_count, batch_size,
|
||||
cfg_scale, seed, height, width, separate_prompts, normalize_prompt_weights, save_individual_images,
|
||||
save_grid, group_by_prompt, save_as_jpg, st.session_state["use_GFPGAN"], st.session_state["use_RealESRGAN"], st.session_state["RealESRGAN_model"],
|
||||
variant_amount=variant_amount, variant_seed=variant_seed, write_info_files=write_info_files)
|
||||
|
||||
message.success('Render Complete: ' + info + '; Stats: ' + stats, icon="✅")
|
||||
|
||||
#history_tab,col1,col2,col3,PlaceHolder,col1_cont,col2_cont,col3_cont = st.session_state['historyTab']
|
||||
|
||||
#if 'latestImages' in st.session_state:
|
||||
#for i in output_images:
|
||||
##push the new image to the list of latest images and remove the oldest one
|
||||
##remove the last index from the list\
|
||||
#st.session_state['latestImages'].pop()
|
||||
##add the new image to the start of the list
|
||||
#st.session_state['latestImages'].insert(0, i)
|
||||
#PlaceHolder.empty()
|
||||
#with PlaceHolder.container():
|
||||
#col1, col2, col3 = st.columns(3)
|
||||
#col1_cont = st.container()
|
||||
#col2_cont = st.container()
|
||||
#col3_cont = st.container()
|
||||
#images = st.session_state['latestImages']
|
||||
#with col1_cont:
|
||||
#with col1:
|
||||
#[st.image(images[index]) for index in [0, 3, 6] if index < len(images)]
|
||||
#with col2_cont:
|
||||
#with col2:
|
||||
#[st.image(images[index]) for index in [1, 4, 7] if index < len(images)]
|
||||
#with col3_cont:
|
||||
#with col3:
|
||||
#[st.image(images[index]) for index in [2, 5, 8] if index < len(images)]
|
||||
#historyGallery = st.empty()
|
||||
|
||||
## check if output_images length is the same as seeds length
|
||||
#with gallery_tab:
|
||||
#st.markdown(createHTMLGallery(output_images,seeds), unsafe_allow_html=True)
|
||||
|
||||
|
||||
#st.session_state['historyTab'] = [history_tab,col1,col2,col3,PlaceHolder,col1_cont,col2_cont,col3_cont]
|
||||
|
||||
except (StopException, KeyError):
|
||||
print(f"Received Streamlit StopException")
|
||||
|
||||
# this will render all the images at the end of the generation but its better if its moved to a second tab inside col2 and shown as a gallery.
|
||||
# use the current col2 first tab to show the preview_img and update it as its generated.
|
||||
#preview_image.image(output_images)
|
||||
|
||||
#on import run init
|
||||
def createHTMLGallery(images,info):
|
||||
html3 = """
|
||||
<div class="gallery-history" style="
|
||||
display: flex;
|
||||
flex-wrap: wrap;
|
||||
align-items: flex-start;">
|
||||
"""
|
||||
mkdwn_array = []
|
||||
for i in images:
|
||||
try:
|
||||
seed = info[images.index(i)]
|
||||
except:
|
||||
seed = ' '
|
||||
image_io = BytesIO()
|
||||
i.save(image_io, 'PNG')
|
||||
width, height = i.size
|
||||
#get random number for the id
|
||||
image_id = "%s" % (str(images.index(i)))
|
||||
(data, mimetype) = STImage._normalize_to_bytes(image_io.getvalue(), width, 'auto')
|
||||
this_file = in_memory_file_manager.add(data, mimetype, image_id)
|
||||
img_str = this_file.url
|
||||
#img_str = 'data:image/png;base64,' + b64encode(image_io.getvalue()).decode('ascii')
|
||||
#get image size
|
||||
|
||||
#make sure the image is not bigger then 150px but keep the aspect ratio
|
||||
if width > 150:
|
||||
height = int(height * (150/width))
|
||||
width = 150
|
||||
if height > 150:
|
||||
width = int(width * (150/height))
|
||||
height = 150
|
||||
|
||||
#mkdwn = f"""<img src="{img_str}" alt="Image" with="200" height="200" />"""
|
||||
mkdwn = f'''<div class="gallery" style="margin: 3px;" >
|
||||
<a href="{img_str}">
|
||||
<img src="{img_str}" alt="Image" width="{width}" height="{height}">
|
||||
</a>
|
||||
<div class="desc" style="text-align: center; opacity: 40%;">{seed}</div>
|
||||
</div>
|
||||
'''
|
||||
mkdwn_array.append(mkdwn)
|
||||
html3 += "".join(mkdwn_array)
|
||||
html3 += '</div>'
|
||||
return html3
|
780
scripts/txt2vid.py
Normal file
780
scripts/txt2vid.py
Normal file
@ -0,0 +1,780 @@
|
||||
# base webui import and utils.
|
||||
from webui_streamlit import st
|
||||
from sd_utils import *
|
||||
|
||||
# streamlit imports
|
||||
from streamlit import StopException
|
||||
from streamlit.runtime.in_memory_file_manager import in_memory_file_manager
|
||||
from streamlit.elements import image as STImage
|
||||
|
||||
#other imports
|
||||
|
||||
import os
|
||||
from PIL import Image
|
||||
import torch
|
||||
import numpy as np
|
||||
import time, inspect, timeit
|
||||
import torch
|
||||
from torch import autocast
|
||||
from io import BytesIO
|
||||
import imageio
|
||||
from slugify import slugify
|
||||
|
||||
# Temp imports
|
||||
|
||||
# these are for testing txt2vid, should be removed and we should use things from our own code.
|
||||
from diffusers import StableDiffusionPipeline
|
||||
from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
|
||||
|
||||
# end of imports
|
||||
#---------------------------------------------------------------------------------------------------------------
|
||||
|
||||
try:
|
||||
# this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start.
|
||||
from transformers import logging
|
||||
|
||||
logging.set_verbosity_error()
|
||||
except:
|
||||
pass
|
||||
|
||||
class plugin_info():
|
||||
plugname = "txt2img"
|
||||
description = "Text to Image"
|
||||
isTab = True
|
||||
displayPriority = 1
|
||||
|
||||
|
||||
if os.path.exists(os.path.join(st.session_state['defaults'].general.GFPGAN_dir, "experiments", "pretrained_models", "GFPGANv1.3.pth")):
|
||||
GFPGAN_available = True
|
||||
else:
|
||||
GFPGAN_available = False
|
||||
|
||||
if os.path.exists(os.path.join(st.session_state['defaults'].general.RealESRGAN_dir, "experiments","pretrained_models", f"{st.session_state['defaults'].txt2vid.RealESRGAN_model}.pth")):
|
||||
RealESRGAN_available = True
|
||||
else:
|
||||
RealESRGAN_available = False
|
||||
|
||||
#
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
@torch.no_grad()
|
||||
def diffuse(
|
||||
pipe,
|
||||
cond_embeddings, # text conditioning, should be (1, 77, 768)
|
||||
cond_latents, # image conditioning, should be (1, 4, 64, 64)
|
||||
num_inference_steps,
|
||||
cfg_scale,
|
||||
eta,
|
||||
):
|
||||
|
||||
torch_device = cond_latents.get_device()
|
||||
|
||||
# classifier guidance: add the unconditional embedding
|
||||
max_length = cond_embeddings.shape[1] # 77
|
||||
uncond_input = pipe.tokenizer([""], padding="max_length", max_length=max_length, return_tensors="pt")
|
||||
uncond_embeddings = pipe.text_encoder(uncond_input.input_ids.to(torch_device))[0]
|
||||
text_embeddings = torch.cat([uncond_embeddings, cond_embeddings])
|
||||
|
||||
# if we use LMSDiscreteScheduler, let's make sure latents are mulitplied by sigmas
|
||||
if isinstance(pipe.scheduler, LMSDiscreteScheduler):
|
||||
cond_latents = cond_latents * pipe.scheduler.sigmas[0]
|
||||
|
||||
# init the scheduler
|
||||
accepts_offset = "offset" in set(inspect.signature(pipe.scheduler.set_timesteps).parameters.keys())
|
||||
extra_set_kwargs = {}
|
||||
if accepts_offset:
|
||||
extra_set_kwargs["offset"] = 1
|
||||
|
||||
pipe.scheduler.set_timesteps(num_inference_steps + st.session_state.sampling_steps, **extra_set_kwargs)
|
||||
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
||||
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
||||
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
||||
# and should be between [0, 1]
|
||||
accepts_eta = "eta" in set(inspect.signature(pipe.scheduler.step).parameters.keys())
|
||||
extra_step_kwargs = {}
|
||||
if accepts_eta:
|
||||
extra_step_kwargs["eta"] = eta
|
||||
|
||||
|
||||
step_counter = 0
|
||||
inference_counter = 0
|
||||
|
||||
if "current_chunk_speed" not in st.session_state:
|
||||
st.session_state["current_chunk_speed"] = 0
|
||||
|
||||
if "previous_chunk_speed_list" not in st.session_state:
|
||||
st.session_state["previous_chunk_speed_list"] = [0]
|
||||
st.session_state["previous_chunk_speed_list"].append(st.session_state["current_chunk_speed"])
|
||||
|
||||
if "update_preview_frequency_list" not in st.session_state:
|
||||
st.session_state["update_preview_frequency_list"] = [0]
|
||||
st.session_state["update_preview_frequency_list"].append(st.session_state['defaults'].txt2vid.update_preview_frequency)
|
||||
|
||||
|
||||
# diffuse!
|
||||
for i, t in enumerate(pipe.scheduler.timesteps):
|
||||
start = timeit.default_timer()
|
||||
|
||||
#status_text.text(f"Running step: {step_counter}{total_number_steps} {percent} | {duration:.2f}{speed}")
|
||||
|
||||
# expand the latents for classifier free guidance
|
||||
latent_model_input = torch.cat([cond_latents] * 2)
|
||||
if isinstance(pipe.scheduler, LMSDiscreteScheduler):
|
||||
sigma = pipe.scheduler.sigmas[i]
|
||||
latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)
|
||||
|
||||
# predict the noise residual
|
||||
noise_pred = pipe.unet(latent_model_input, t, encoder_hidden_states=text_embeddings)["sample"]
|
||||
|
||||
# cfg
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + cfg_scale * (noise_pred_text - noise_pred_uncond)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
if isinstance(pipe.scheduler, LMSDiscreteScheduler):
|
||||
cond_latents = pipe.scheduler.step(noise_pred, i, cond_latents, **extra_step_kwargs)["prev_sample"]
|
||||
else:
|
||||
cond_latents = pipe.scheduler.step(noise_pred, t, cond_latents, **extra_step_kwargs)["prev_sample"]
|
||||
|
||||
#print (st.session_state["update_preview_frequency"])
|
||||
#update the preview image if it is enabled and the frequency matches the step_counter
|
||||
if st.session_state['defaults'].txt2vid.update_preview:
|
||||
step_counter += 1
|
||||
|
||||
if st.session_state['defaults'].txt2vid.update_preview_frequency == step_counter or step_counter == st.session_state.sampling_steps:
|
||||
if st.session_state.dynamic_preview_frequency:
|
||||
st.session_state["current_chunk_speed"], st.session_state["previous_chunk_speed_list"], st.session_state['defaults'].txt2vid.update_preview_frequency, st.session_state["avg_update_preview_frequency"] = optimize_update_preview_frequency(st.session_state["current_chunk_speed"], st.session_state["previous_chunk_speed_list"], st.session_state['defaults'].txt2vid.update_preview_frequency, st.session_state["update_preview_frequency_list"])
|
||||
|
||||
#scale and decode the image latents with vae
|
||||
cond_latents_2 = 1 / 0.18215 * cond_latents
|
||||
image = pipe.vae.decode(cond_latents_2)
|
||||
|
||||
# generate output numpy image as uint8
|
||||
image = torch.clamp((image["sample"] + 1.0) / 2.0, min=0.0, max=1.0)
|
||||
image = transforms.ToPILImage()(image.squeeze_(0))
|
||||
|
||||
st.session_state["preview_image"].image(image)
|
||||
|
||||
step_counter = 0
|
||||
|
||||
duration = timeit.default_timer() - start
|
||||
|
||||
st.session_state["current_chunk_speed"] = duration
|
||||
|
||||
if duration >= 1:
|
||||
speed = "s/it"
|
||||
else:
|
||||
speed = "it/s"
|
||||
duration = 1 / duration
|
||||
|
||||
if i > st.session_state.sampling_steps:
|
||||
inference_counter += 1
|
||||
inference_percent = int(100 * float(inference_counter + 1 if inference_counter < num_inference_steps else num_inference_steps)/float(num_inference_steps))
|
||||
inference_progress = f"{inference_counter + 1 if inference_counter < num_inference_steps else num_inference_steps}/{num_inference_steps} {inference_percent}% "
|
||||
else:
|
||||
inference_progress = ""
|
||||
|
||||
percent = int(100 * float(i+1 if i+1 < st.session_state.sampling_steps else st.session_state.sampling_steps)/float(st.session_state.sampling_steps))
|
||||
frames_percent = int(100 * float(st.session_state.current_frame if st.session_state.current_frame < st.session_state.max_frames else st.session_state.max_frames)/float(st.session_state.max_frames))
|
||||
|
||||
st.session_state["progress_bar_text"].text(
|
||||
f"Running step: {i+1 if i+1 < st.session_state.sampling_steps else st.session_state.sampling_steps}/{st.session_state.sampling_steps} "
|
||||
f"{percent if percent < 100 else 100}% {inference_progress}{duration:.2f}{speed} | "
|
||||
f"Frame: {st.session_state.current_frame + 1 if st.session_state.current_frame < st.session_state.max_frames else st.session_state.max_frames}/{st.session_state.max_frames} "
|
||||
f"{frames_percent if frames_percent < 100 else 100}% {st.session_state.frame_duration:.2f}{st.session_state.frame_speed}"
|
||||
)
|
||||
st.session_state["progress_bar"].progress(percent if percent < 100 else 100)
|
||||
|
||||
return image
|
||||
|
||||
#
|
||||
def txt2vid(
|
||||
# --------------------------------------
|
||||
# args you probably want to change
|
||||
prompts = ["blueberry spaghetti", "strawberry spaghetti"], # prompt to dream about
|
||||
gpu:int = st.session_state['defaults'].general.gpu, # id of the gpu to run on
|
||||
#name:str = 'test', # name of this project, for the output directory
|
||||
#rootdir:str = st.session_state['defaults'].general.outdir,
|
||||
num_steps:int = 200, # number of steps between each pair of sampled points
|
||||
max_frames:int = 10000, # number of frames to write and then exit the script
|
||||
num_inference_steps:int = 50, # more (e.g. 100, 200 etc) can create slightly better images
|
||||
cfg_scale:float = 5.0, # can depend on the prompt. usually somewhere between 3-10 is good
|
||||
do_loop = False,
|
||||
use_lerp_for_text = False,
|
||||
seeds = None,
|
||||
quality:int = 100, # for jpeg compression of the output images
|
||||
eta:float = 0.0,
|
||||
width:int = 256,
|
||||
height:int = 256,
|
||||
weights_path = "CompVis/stable-diffusion-v1-4",
|
||||
scheduler="klms", # choices: default, ddim, klms
|
||||
disable_tqdm = False,
|
||||
#-----------------------------------------------
|
||||
beta_start = 0.0001,
|
||||
beta_end = 0.00012,
|
||||
beta_schedule = "scaled_linear",
|
||||
starting_image=None
|
||||
):
|
||||
"""
|
||||
prompt = ["blueberry spaghetti", "strawberry spaghetti"], # prompt to dream about
|
||||
gpu:int = st.session_state['defaults'].general.gpu, # id of the gpu to run on
|
||||
#name:str = 'test', # name of this project, for the output directory
|
||||
#rootdir:str = st.session_state['defaults'].general.outdir,
|
||||
num_steps:int = 200, # number of steps between each pair of sampled points
|
||||
max_frames:int = 10000, # number of frames to write and then exit the script
|
||||
num_inference_steps:int = 50, # more (e.g. 100, 200 etc) can create slightly better images
|
||||
cfg_scale:float = 5.0, # can depend on the prompt. usually somewhere between 3-10 is good
|
||||
do_loop = False,
|
||||
use_lerp_for_text = False,
|
||||
seed = None,
|
||||
quality:int = 100, # for jpeg compression of the output images
|
||||
eta:float = 0.0,
|
||||
width:int = 256,
|
||||
height:int = 256,
|
||||
weights_path = "CompVis/stable-diffusion-v1-4",
|
||||
scheduler="klms", # choices: default, ddim, klms
|
||||
disable_tqdm = False,
|
||||
beta_start = 0.0001,
|
||||
beta_end = 0.00012,
|
||||
beta_schedule = "scaled_linear"
|
||||
"""
|
||||
mem_mon = MemUsageMonitor('MemMon')
|
||||
mem_mon.start()
|
||||
|
||||
|
||||
seeds = seed_to_int(seeds)
|
||||
|
||||
# We add an extra frame because most
|
||||
# of the time the first frame is just the noise.
|
||||
#max_frames +=1
|
||||
|
||||
assert torch.cuda.is_available()
|
||||
assert height % 8 == 0 and width % 8 == 0
|
||||
torch.manual_seed(seeds)
|
||||
torch_device = f"cuda:{gpu}"
|
||||
|
||||
# init the output dir
|
||||
sanitized_prompt = slugify(prompts)
|
||||
|
||||
full_path = os.path.join(os.getcwd(), st.session_state['defaults'].general.outdir, "txt2vid-samples", "samples", sanitized_prompt)
|
||||
|
||||
if len(full_path) > 220:
|
||||
sanitized_prompt = sanitized_prompt[:220-len(full_path)]
|
||||
full_path = os.path.join(os.getcwd(), st.session_state['defaults'].general.outdir, "txt2vid-samples", "samples", sanitized_prompt)
|
||||
|
||||
os.makedirs(full_path, exist_ok=True)
|
||||
|
||||
# Write prompt info to file in output dir so we can keep track of what we did
|
||||
if st.session_state.write_info_files:
|
||||
with open(os.path.join(full_path , f'{slugify(str(seeds))}_config.json' if len(prompts) > 1 else "prompts_config.json"), "w") as outfile:
|
||||
outfile.write(json.dumps(
|
||||
dict(
|
||||
prompts = prompts,
|
||||
gpu = gpu,
|
||||
num_steps = num_steps,
|
||||
max_frames = max_frames,
|
||||
num_inference_steps = num_inference_steps,
|
||||
cfg_scale = cfg_scale,
|
||||
do_loop = do_loop,
|
||||
use_lerp_for_text = use_lerp_for_text,
|
||||
seeds = seeds,
|
||||
quality = quality,
|
||||
eta = eta,
|
||||
width = width,
|
||||
height = height,
|
||||
weights_path = weights_path,
|
||||
scheduler=scheduler,
|
||||
disable_tqdm = disable_tqdm,
|
||||
beta_start = beta_start,
|
||||
beta_end = beta_end,
|
||||
beta_schedule = beta_schedule
|
||||
),
|
||||
indent=2,
|
||||
sort_keys=False,
|
||||
))
|
||||
|
||||
#print(scheduler)
|
||||
default_scheduler = PNDMScheduler(
|
||||
beta_start=beta_start, beta_end=beta_end, beta_schedule=beta_schedule
|
||||
)
|
||||
# ------------------------------------------------------------------------------
|
||||
#Schedulers
|
||||
ddim_scheduler = DDIMScheduler(
|
||||
beta_start=beta_start,
|
||||
beta_end=beta_end,
|
||||
beta_schedule=beta_schedule,
|
||||
clip_sample=False,
|
||||
set_alpha_to_one=False,
|
||||
)
|
||||
|
||||
klms_scheduler = LMSDiscreteScheduler(
|
||||
beta_start=beta_start, beta_end=beta_end, beta_schedule=beta_schedule
|
||||
)
|
||||
|
||||
SCHEDULERS = dict(default=default_scheduler, ddim=ddim_scheduler, klms=klms_scheduler)
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
st.session_state["progress_bar_text"].text("Loading models...")
|
||||
|
||||
try:
|
||||
if "model" in st.session_state:
|
||||
del st.session_state["model"]
|
||||
except:
|
||||
pass
|
||||
|
||||
#print (st.session_state["weights_path"] != weights_path)
|
||||
|
||||
try:
|
||||
if not "pipe" in st.session_state or st.session_state["weights_path"] != weights_path:
|
||||
if st.session_state["weights_path"] != weights_path:
|
||||
del st.session_state["weights_path"]
|
||||
|
||||
st.session_state["weights_path"] = weights_path
|
||||
st.session_state["pipe"] = StableDiffusionPipeline.from_pretrained(
|
||||
weights_path,
|
||||
use_local_file=True,
|
||||
use_auth_token=True,
|
||||
torch_dtype=torch.float16 if st.session_state['defaults'].general.use_float16 else None,
|
||||
revision="fp16" if not st.session_state['defaults'].general.no_half else None
|
||||
)
|
||||
|
||||
st.session_state["pipe"].unet.to(torch_device)
|
||||
st.session_state["pipe"].vae.to(torch_device)
|
||||
st.session_state["pipe"].text_encoder.to(torch_device)
|
||||
|
||||
if st.session_state.defaults.general.enable_attention_slicing:
|
||||
st.session_state["pipe"].enable_attention_slicing()
|
||||
if st.session_state.defaults.general.enable_minimal_memory_usage:
|
||||
st.session_state["pipe"].enable_minimal_memory_usage()
|
||||
|
||||
print("Tx2Vid Model Loaded")
|
||||
else:
|
||||
print("Tx2Vid Model already Loaded")
|
||||
|
||||
except:
|
||||
#del st.session_state["weights_path"]
|
||||
#del st.session_state["pipe"]
|
||||
|
||||
st.session_state["weights_path"] = weights_path
|
||||
st.session_state["pipe"] = StableDiffusionPipeline.from_pretrained(
|
||||
weights_path,
|
||||
use_local_file=True,
|
||||
use_auth_token=True,
|
||||
torch_dtype=torch.float16 if st.session_state['defaults'].general.use_float16 else None,
|
||||
revision="fp16" if not st.session_state['defaults'].general.no_half else None
|
||||
)
|
||||
|
||||
st.session_state["pipe"].unet.to(torch_device)
|
||||
st.session_state["pipe"].vae.to(torch_device)
|
||||
st.session_state["pipe"].text_encoder.to(torch_device)
|
||||
|
||||
if st.session_state.defaults.general.enable_attention_slicing:
|
||||
st.session_state["pipe"].enable_attention_slicing()
|
||||
|
||||
|
||||
print("Tx2Vid Model Loaded")
|
||||
|
||||
st.session_state["pipe"].scheduler = SCHEDULERS[scheduler]
|
||||
|
||||
# get the conditional text embeddings based on the prompt
|
||||
text_input = st.session_state["pipe"].tokenizer(prompts, padding="max_length", max_length=st.session_state["pipe"].tokenizer.model_max_length, truncation=True, return_tensors="pt")
|
||||
cond_embeddings = st.session_state["pipe"].text_encoder(text_input.input_ids.to(torch_device))[0] # shape [1, 77, 768]
|
||||
|
||||
#
|
||||
if st.session_state.defaults.general.use_sd_concepts_library:
|
||||
|
||||
prompt_tokens = re.findall('<([a-zA-Z0-9-]+)>', prompts)
|
||||
|
||||
if prompt_tokens:
|
||||
# compviz
|
||||
#tokenizer = (st.session_state["model"] if not st.session_state['defaults'].general.optimized else st.session_state.modelCS).cond_stage_model.tokenizer
|
||||
#text_encoder = (st.session_state["model"] if not st.session_state['defaults'].general.optimized else st.session_state.modelCS).cond_stage_model.transformer
|
||||
|
||||
# diffusers
|
||||
tokenizer = st.session_state.pipe.tokenizer
|
||||
text_encoder = st.session_state.pipe.text_encoder
|
||||
|
||||
ext = ('pt', 'bin')
|
||||
#print (prompt_tokens)
|
||||
|
||||
if len(prompt_tokens) > 1:
|
||||
for token_name in prompt_tokens:
|
||||
embedding_path = os.path.join(st.session_state['defaults'].general.sd_concepts_library_folder, token_name)
|
||||
if os.path.exists(embedding_path):
|
||||
for files in os.listdir(embedding_path):
|
||||
if files.endswith(ext):
|
||||
load_learned_embed_in_clip(f"{os.path.join(embedding_path, files)}", text_encoder, tokenizer, f"<{token_name}>")
|
||||
else:
|
||||
embedding_path = os.path.join(st.session_state['defaults'].general.sd_concepts_library_folder, prompt_tokens[0])
|
||||
if os.path.exists(embedding_path):
|
||||
for files in os.listdir(embedding_path):
|
||||
if files.endswith(ext):
|
||||
load_learned_embed_in_clip(f"{os.path.join(embedding_path, files)}", text_encoder, tokenizer, f"<{prompt_tokens[0]}>")
|
||||
|
||||
# sample a source
|
||||
init1 = torch.randn((1, st.session_state["pipe"].unet.in_channels, height // 8, width // 8), device=torch_device)
|
||||
|
||||
if do_loop:
|
||||
prompts = [prompts, prompts]
|
||||
seeds = [seeds, seeds]
|
||||
#first_seed, *seeds = seeds
|
||||
#prompts.append(prompts)
|
||||
#seeds.append(first_seed)
|
||||
|
||||
|
||||
# iterate the loop
|
||||
frames = []
|
||||
frame_index = 0
|
||||
|
||||
st.session_state["total_frames_avg_duration"] = []
|
||||
st.session_state["total_frames_avg_speed"] = []
|
||||
|
||||
try:
|
||||
while frame_index < max_frames:
|
||||
st.session_state["frame_duration"] = 0
|
||||
st.session_state["frame_speed"] = 0
|
||||
st.session_state["current_frame"] = frame_index
|
||||
|
||||
# sample the destination
|
||||
init2 = torch.randn((1, st.session_state["pipe"].unet.in_channels, height // 8, width // 8), device=torch_device)
|
||||
|
||||
for i, t in enumerate(np.linspace(0, 1, max_frames)):
|
||||
start = timeit.default_timer()
|
||||
print(f"COUNT: {frame_index+1}/{max_frames}")
|
||||
|
||||
#if use_lerp_for_text:
|
||||
#init = torch.lerp(init1, init2, float(t))
|
||||
#else:
|
||||
#init = slerp(gpu, float(t), init1, init2)
|
||||
|
||||
init = slerp(gpu, float(t), init1, init2)
|
||||
|
||||
with autocast("cuda"):
|
||||
image = diffuse(st.session_state["pipe"], cond_embeddings, init, num_inference_steps, cfg_scale, eta)
|
||||
|
||||
#im = Image.fromarray(image)
|
||||
outpath = os.path.join(full_path, 'frame%06d.png' % frame_index)
|
||||
image.save(outpath, quality=quality)
|
||||
|
||||
# send the image to the UI to update it
|
||||
#st.session_state["preview_image"].image(im)
|
||||
|
||||
#append the frames to the frames list so we can use them later.
|
||||
frames.append(np.asarray(image))
|
||||
|
||||
#increase frame_index counter.
|
||||
frame_index += 1
|
||||
|
||||
st.session_state["current_frame"] = frame_index
|
||||
|
||||
duration = timeit.default_timer() - start
|
||||
|
||||
if duration >= 1:
|
||||
speed = "s/it"
|
||||
else:
|
||||
speed = "it/s"
|
||||
duration = 1 / duration
|
||||
|
||||
st.session_state["frame_duration"] = duration
|
||||
st.session_state["frame_speed"] = speed
|
||||
|
||||
init1 = init2
|
||||
|
||||
except StopException:
|
||||
pass
|
||||
|
||||
|
||||
if st.session_state['save_video']:
|
||||
# write video to memory
|
||||
#output = io.BytesIO()
|
||||
#writer = imageio.get_writer(os.path.join(os.getcwd(), st.session_state['defaults'].general.outdir, "txt2vid-samples"), im, extension=".mp4", fps=30)
|
||||
try:
|
||||
video_path = os.path.join(os.getcwd(), st.session_state['defaults'].general.outdir, "txt2vid-samples","temp.mp4")
|
||||
writer = imageio.get_writer(video_path, fps=24)
|
||||
for frame in frames:
|
||||
writer.append_data(frame)
|
||||
writer.close()
|
||||
except:
|
||||
print("Can't save video, skipping.")
|
||||
|
||||
# show video preview on the UI
|
||||
st.session_state["preview_video"].video(open(video_path, 'rb').read())
|
||||
|
||||
mem_max_used, mem_total = mem_mon.read_and_stop()
|
||||
time_diff = time.time()- start
|
||||
|
||||
info = f"""
|
||||
{prompts}
|
||||
Sampling Steps: {num_steps}, Sampler: {scheduler}, CFG scale: {cfg_scale}, Seed: {seeds}, Max Frames: {max_frames}""".strip()
|
||||
stats = f'''
|
||||
Took { round(time_diff, 2) }s total ({ round(time_diff/(max_frames),2) }s per image)
|
||||
Peak memory usage: { -(mem_max_used // -1_048_576) } MiB / { -(mem_total // -1_048_576) } MiB / { round(mem_max_used/mem_total*100, 3) }%'''
|
||||
|
||||
return video_path, seeds, info, stats
|
||||
|
||||
#on import run init
|
||||
def createHTMLGallery(images,info):
|
||||
html3 = """
|
||||
<div class="gallery-history" style="
|
||||
display: flex;
|
||||
flex-wrap: wrap;
|
||||
align-items: flex-start;">
|
||||
"""
|
||||
mkdwn_array = []
|
||||
for i in images:
|
||||
try:
|
||||
seed = info[images.index(i)]
|
||||
except:
|
||||
seed = ' '
|
||||
image_io = BytesIO()
|
||||
i.save(image_io, 'PNG')
|
||||
width, height = i.size
|
||||
#get random number for the id
|
||||
image_id = "%s" % (str(images.index(i)))
|
||||
(data, mimetype) = STImage._normalize_to_bytes(image_io.getvalue(), width, 'auto')
|
||||
this_file = in_memory_file_manager.add(data, mimetype, image_id)
|
||||
img_str = this_file.url
|
||||
#img_str = 'data:image/png;base64,' + b64encode(image_io.getvalue()).decode('ascii')
|
||||
#get image size
|
||||
|
||||
#make sure the image is not bigger then 150px but keep the aspect ratio
|
||||
if width > 150:
|
||||
height = int(height * (150/width))
|
||||
width = 150
|
||||
if height > 150:
|
||||
width = int(width * (150/height))
|
||||
height = 150
|
||||
|
||||
#mkdwn = f"""<img src="{img_str}" alt="Image" with="200" height="200" />"""
|
||||
mkdwn = f'''<div class="gallery" style="margin: 3px;" >
|
||||
<a href="{img_str}">
|
||||
<img src="{img_str}" alt="Image" width="{width}" height="{height}">
|
||||
</a>
|
||||
<div class="desc" style="text-align: center; opacity: 40%;">{seed}</div>
|
||||
</div>
|
||||
'''
|
||||
mkdwn_array.append(mkdwn)
|
||||
|
||||
html3 += "".join(mkdwn_array)
|
||||
html3 += '</div>'
|
||||
return html3
|
||||
#
|
||||
def layout():
|
||||
with st.form("txt2vid-inputs"):
|
||||
st.session_state["generation_mode"] = "txt2vid"
|
||||
|
||||
input_col1, generate_col1 = st.columns([10,1])
|
||||
with input_col1:
|
||||
#prompt = st.text_area("Input Text","")
|
||||
prompt = st.text_input("Input Text","", placeholder="A corgi wearing a top hat as an oil painting.")
|
||||
|
||||
# Every form must have a submit button, the extra blank spaces is a temp way to align it with the input field. Needs to be done in CSS or some other way.
|
||||
generate_col1.write("")
|
||||
generate_col1.write("")
|
||||
generate_button = generate_col1.form_submit_button("Generate")
|
||||
|
||||
# creating the page layout using columns
|
||||
col1, col2, col3 = st.columns([1,2,1], gap="large")
|
||||
|
||||
with col1:
|
||||
width = st.slider("Width:", min_value=64, max_value=2048, value=st.session_state['defaults'].txt2vid.width, step=64)
|
||||
height = st.slider("Height:", min_value=64, max_value=2048, value=st.session_state['defaults'].txt2vid.height, step=64)
|
||||
cfg_scale = st.slider("CFG (Classifier Free Guidance Scale):", min_value=1.0, max_value=30.0, value=st.session_state['defaults'].txt2vid.cfg_scale, step=0.5, help="How strongly the image should follow the prompt.")
|
||||
|
||||
#uploaded_images = st.file_uploader("Upload Image", accept_multiple_files=False, type=["png", "jpg", "jpeg", "webp"],
|
||||
#help="Upload an image which will be used for the image to image generation.")
|
||||
seed = st.text_input("Seed:", value=st.session_state['defaults'].txt2vid.seed, help=" The seed to use, if left blank a random seed will be generated.")
|
||||
#batch_count = st.slider("Batch count.", min_value=1, max_value=100, value=st.session_state['defaults'].txt2vid.batch_count, step=1, help="How many iterations or batches of images to generate in total.")
|
||||
#batch_size = st.slider("Batch size", min_value=1, max_value=250, value=st.session_state['defaults'].txt2vid.batch_size, step=1,
|
||||
#help="How many images are at once in a batch.\
|
||||
#It increases the VRAM usage a lot but if you have enough VRAM it can reduce the time it takes to finish generation as more images are generated at once.\
|
||||
#Default: 1")
|
||||
|
||||
st.session_state["max_frames"] = int(st.text_input("Max Frames:", value=st.session_state['defaults'].txt2vid.max_frames, help="Specify the max number of frames you want to generate."))
|
||||
|
||||
with st.expander("Preview Settings"):
|
||||
st.session_state["update_preview"] = st.checkbox("Update Image Preview", value=st.session_state['defaults'].txt2vid.update_preview,
|
||||
help="If enabled the image preview will be updated during the generation instead of at the end. \
|
||||
You can use the Update Preview \Frequency option bellow to customize how frequent it's updated. \
|
||||
By default this is enabled and the frequency is set to 1 step.")
|
||||
|
||||
st.session_state["update_preview_frequency"] = st.text_input("Update Image Preview Frequency", value=st.session_state['defaults'].txt2vid.update_preview_frequency,
|
||||
help="Frequency in steps at which the the preview image is updated. By default the frequency \
|
||||
is set to 1 step.")
|
||||
|
||||
#
|
||||
|
||||
|
||||
|
||||
with col2:
|
||||
preview_tab, gallery_tab = st.tabs(["Preview", "Gallery"])
|
||||
|
||||
with preview_tab:
|
||||
#st.write("Image")
|
||||
#Image for testing
|
||||
#image = Image.open(requests.get("https://icon-library.com/images/image-placeholder-icon/image-placeholder-icon-13.jpg", stream=True).raw).convert('RGB')
|
||||
#new_image = image.resize((175, 240))
|
||||
#preview_image = st.image(image)
|
||||
|
||||
# create an empty container for the image, progress bar, etc so we can update it later and use session_state to hold them globally.
|
||||
st.session_state["preview_image"] = st.empty()
|
||||
|
||||
st.session_state["loading"] = st.empty()
|
||||
|
||||
st.session_state["progress_bar_text"] = st.empty()
|
||||
st.session_state["progress_bar"] = st.empty()
|
||||
|
||||
#generate_video = st.empty()
|
||||
st.session_state["preview_video"] = st.empty()
|
||||
|
||||
message = st.empty()
|
||||
|
||||
with gallery_tab:
|
||||
st.write('Here should be the image gallery, if I could make a grid in streamlit.')
|
||||
|
||||
with col3:
|
||||
# If we have custom models available on the "models/custom"
|
||||
#folder then we show a menu to select which model we want to use, otherwise we use the main model for SD
|
||||
if st.session_state["CustomModel_available"]:
|
||||
custom_model = st.selectbox("Custom Model:", st.session_state["defaults"].txt2vid.custom_models_list,
|
||||
index=st.session_state["defaults"].txt2vid.custom_models_list.index(st.session_state["defaults"].txt2vid.default_model),
|
||||
help="Select the model you want to use. This option is only available if you have custom models \
|
||||
on your 'models/custom' folder. The model name that will be shown here is the same as the name\
|
||||
the file for the model has on said folder, it is recommended to give the .ckpt file a name that \
|
||||
will make it easier for you to distinguish it from other models. Default: Stable Diffusion v1.4")
|
||||
else:
|
||||
custom_model = "CompVis/stable-diffusion-v1-4"
|
||||
|
||||
#st.session_state["weights_path"] = custom_model
|
||||
#else:
|
||||
#custom_model = "CompVis/stable-diffusion-v1-4"
|
||||
#st.session_state["weights_path"] = f"CompVis/{slugify(custom_model.lower())}"
|
||||
|
||||
st.session_state.sampling_steps = st.slider("Sampling Steps",
|
||||
value=st.session_state['defaults'].txt2vid.sampling_steps,
|
||||
min_value=st.session_state['defaults'].txt2vid.slider_bounds.sampling.lower,
|
||||
max_value=st.session_state['defaults'].txt2vid.slider_bounds.sampling.upper,
|
||||
step=st.session_state['defaults'].txt2vid.slider_steps.sampling,
|
||||
help="Number of steps between each pair of sampled points")
|
||||
st.session_state.num_inference_steps = st.slider("Inference Steps:", value=st.session_state['defaults'].txt2vid.num_inference_steps, min_value=10,step=10, max_value=500,
|
||||
help="Higher values (e.g. 100, 200 etc) can create better images.")
|
||||
|
||||
#sampler_name_list = ["k_lms", "k_euler", "k_euler_a", "k_dpm_2", "k_dpm_2_a", "k_heun", "PLMS", "DDIM"]
|
||||
#sampler_name = st.selectbox("Sampling method", sampler_name_list,
|
||||
#index=sampler_name_list.index(st.session_state['defaults'].txt2vid.default_sampler), help="Sampling method to use. Default: k_euler")
|
||||
scheduler_name_list = ["klms", "ddim"]
|
||||
scheduler_name = st.selectbox("Scheduler:", scheduler_name_list,
|
||||
index=scheduler_name_list.index(st.session_state['defaults'].txt2vid.scheduler_name), help="Scheduler to use. Default: klms")
|
||||
|
||||
beta_scheduler_type_list = ["scaled_linear", "linear"]
|
||||
beta_scheduler_type = st.selectbox("Beta Schedule Type:", beta_scheduler_type_list,
|
||||
index=beta_scheduler_type_list.index(st.session_state['defaults'].txt2vid.beta_scheduler_type), help="Schedule Type to use. Default: linear")
|
||||
|
||||
|
||||
#basic_tab, advanced_tab = st.tabs(["Basic", "Advanced"])
|
||||
|
||||
#with basic_tab:
|
||||
#summit_on_enter = st.radio("Submit on enter?", ("Yes", "No"), horizontal=True,
|
||||
#help="Press the Enter key to summit, when 'No' is selected you can use the Enter key to write multiple lines.")
|
||||
|
||||
with st.expander("Advanced"):
|
||||
st.session_state["separate_prompts"] = st.checkbox("Create Prompt Matrix.", value=st.session_state['defaults'].txt2vid.separate_prompts,
|
||||
help="Separate multiple prompts using the `|` character, and get all combinations of them.")
|
||||
st.session_state["normalize_prompt_weights"] = st.checkbox("Normalize Prompt Weights.",
|
||||
value=st.session_state['defaults'].txt2vid.normalize_prompt_weights, help="Ensure the sum of all weights add up to 1.0")
|
||||
st.session_state["save_individual_images"] = st.checkbox("Save individual images.",
|
||||
value=st.session_state['defaults'].txt2vid.save_individual_images, help="Save each image generated before any filter or enhancement is applied.")
|
||||
st.session_state["save_video"] = st.checkbox("Save video",value=st.session_state['defaults'].txt2vid.save_video, help="Save a video with all the images generated as frames at the end of the generation.")
|
||||
st.session_state["group_by_prompt"] = st.checkbox("Group results by prompt", value=st.session_state['defaults'].txt2vid.group_by_prompt,
|
||||
help="Saves all the images with the same prompt into the same folder. When using a prompt matrix each prompt combination will have its own folder.")
|
||||
st.session_state["write_info_files"] = st.checkbox("Write Info file", value=st.session_state['defaults'].txt2vid.write_info_files,
|
||||
help="Save a file next to the image with informartion about the generation.")
|
||||
st.session_state["dynamic_preview_frequency"] = st.checkbox("Dynamic Preview Frequency", value=st.session_state['defaults'].txt2vid.dynamic_preview_frequency,
|
||||
help="This option tries to find the best value at which we can update \
|
||||
the preview image during generation while minimizing the impact it has in performance. Default: True")
|
||||
st.session_state["do_loop"] = st.checkbox("Do Loop", value=st.session_state['defaults'].txt2vid.do_loop,
|
||||
help="Do loop")
|
||||
st.session_state["save_as_jpg"] = st.checkbox("Save samples as jpg", value=st.session_state['defaults'].txt2vid.save_as_jpg, help="Saves the images as jpg instead of png.")
|
||||
|
||||
if GFPGAN_available:
|
||||
st.session_state["use_GFPGAN"] = st.checkbox("Use GFPGAN", value=st.session_state['defaults'].txt2vid.use_GFPGAN, help="Uses the GFPGAN model to improve faces after the generation. This greatly improve the quality and consistency of faces but uses extra VRAM. Disable if you need the extra VRAM.")
|
||||
else:
|
||||
st.session_state["use_GFPGAN"] = False
|
||||
|
||||
if RealESRGAN_available:
|
||||
st.session_state["use_RealESRGAN"] = st.checkbox("Use RealESRGAN", value=st.session_state['defaults'].txt2vid.use_RealESRGAN,
|
||||
help="Uses the RealESRGAN model to upscale the images after the generation. This greatly improve the quality and lets you have high resolution images but uses extra VRAM. Disable if you need the extra VRAM.")
|
||||
st.session_state["RealESRGAN_model"] = st.selectbox("RealESRGAN model", ["RealESRGAN_x4plus", "RealESRGAN_x4plus_anime_6B"], index=0)
|
||||
else:
|
||||
st.session_state["use_RealESRGAN"] = False
|
||||
st.session_state["RealESRGAN_model"] = "RealESRGAN_x4plus"
|
||||
|
||||
st.session_state["variant_amount"] = st.slider("Variant Amount:", value=st.session_state['defaults'].txt2vid.variant_amount, min_value=0.0, max_value=1.0, step=0.01)
|
||||
st.session_state["variant_seed"] = st.text_input("Variant Seed:", value=st.session_state['defaults'].txt2vid.seed, help="The seed to use when generating a variant, if left blank a random seed will be generated.")
|
||||
st.session_state["beta_start"] = st.slider("Beta Start:", value=st.session_state['defaults'].txt2vid.beta_start, min_value=0.0001, max_value=0.03, step=0.0001, format="%.4f")
|
||||
st.session_state["beta_end"] = st.slider("Beta End:", value=st.session_state['defaults'].txt2vid.beta_end, min_value=0.0001, max_value=0.03, step=0.0001, format="%.4f")
|
||||
|
||||
if generate_button:
|
||||
#print("Loading models")
|
||||
# load the models when we hit the generate button for the first time, it wont be loaded after that so dont worry.
|
||||
#load_models(False, False, False, st.session_state["RealESRGAN_model"], CustomModel_available=st.session_state["CustomModel_available"], custom_model=custom_model)
|
||||
|
||||
try:
|
||||
# run video generation
|
||||
video, seed, info, stats = txt2vid(prompts=prompt, gpu=st.session_state["defaults"].general.gpu,
|
||||
num_steps=st.session_state.sampling_steps, max_frames=int(st.session_state.max_frames),
|
||||
num_inference_steps=st.session_state.num_inference_steps,
|
||||
cfg_scale=cfg_scale,do_loop=st.session_state["do_loop"],
|
||||
seeds=seed, quality=100, eta=0.0, width=width,
|
||||
height=height, weights_path=custom_model, scheduler=scheduler_name,
|
||||
disable_tqdm=False, beta_start=st.session_state["beta_start"], beta_end=st.session_state["beta_end"],
|
||||
beta_schedule=beta_scheduler_type, starting_image=None)
|
||||
|
||||
#message.success('Done!', icon="✅")
|
||||
message.success('Render Complete: ' + info + '; Stats: ' + stats, icon="✅")
|
||||
|
||||
#history_tab,col1,col2,col3,PlaceHolder,col1_cont,col2_cont,col3_cont = st.session_state['historyTab']
|
||||
|
||||
#if 'latestVideos' in st.session_state:
|
||||
#for i in video:
|
||||
##push the new image to the list of latest images and remove the oldest one
|
||||
##remove the last index from the list\
|
||||
#st.session_state['latestVideos'].pop()
|
||||
##add the new image to the start of the list
|
||||
#st.session_state['latestVideos'].insert(0, i)
|
||||
#PlaceHolder.empty()
|
||||
|
||||
#with PlaceHolder.container():
|
||||
#col1, col2, col3 = st.columns(3)
|
||||
#col1_cont = st.container()
|
||||
#col2_cont = st.container()
|
||||
#col3_cont = st.container()
|
||||
|
||||
#with col1_cont:
|
||||
#with col1:
|
||||
#st.image(st.session_state['latestVideos'][0])
|
||||
#st.image(st.session_state['latestVideos'][3])
|
||||
#st.image(st.session_state['latestVideos'][6])
|
||||
#with col2_cont:
|
||||
#with col2:
|
||||
#st.image(st.session_state['latestVideos'][1])
|
||||
#st.image(st.session_state['latestVideos'][4])
|
||||
#st.image(st.session_state['latestVideos'][7])
|
||||
#with col3_cont:
|
||||
#with col3:
|
||||
#st.image(st.session_state['latestVideos'][2])
|
||||
#st.image(st.session_state['latestVideos'][5])
|
||||
#st.image(st.session_state['latestVideos'][8])
|
||||
#historyGallery = st.empty()
|
||||
|
||||
## check if output_images length is the same as seeds length
|
||||
#with gallery_tab:
|
||||
#st.markdown(createHTMLGallery(video,seed), unsafe_allow_html=True)
|
||||
|
||||
|
||||
#st.session_state['historyTab'] = [history_tab,col1,col2,col3,PlaceHolder,col1_cont,col2_cont,col3_cont]
|
||||
|
||||
except (StopException, KeyError):
|
||||
print(f"Received Streamlit StopException")
|
||||
|
||||
|
565
scripts/webui.py
565
scripts/webui.py
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
2738
scripts/webui_streamlit_old.py
Normal file
2738
scripts/webui_streamlit_old.py
Normal file
File diff suppressed because it is too large
Load Diff
2
setup.py
2
setup.py
@ -1,7 +1,7 @@
|
||||
from setuptools import setup, find_packages
|
||||
|
||||
setup(
|
||||
name='latent-diffusion',
|
||||
name='sd-webui',
|
||||
version='0.0.1',
|
||||
description='',
|
||||
packages=find_packages(),
|
||||
|
2
webui.sh
2
webui.sh
@ -37,7 +37,7 @@ if ! conda env list | grep ".*${ENV_NAME}.*" >/dev/null 2>&1; then
|
||||
ENV_UPDATED=1
|
||||
elif [[ ! -z $CONDA_FORCE_UPDATE && $CONDA_FORCE_UPDATE == "true" ]] || (( $ENV_MODIFIED > $ENV_MODIFIED_CACHED )); then
|
||||
echo "Updating conda env: ${ENV_NAME} ..."
|
||||
conda env update --file $ENV_FILE --prune
|
||||
PIP_EXISTS_ACTION=w conda env update --file $ENV_FILE --prune
|
||||
ENV_UPDATED=1
|
||||
fi
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user