2022-10-24 02:22:40 +03:00
# This file is part of stable-diffusion-webui (https://github.com/Sygil-Dev/sygil-webui/).
2022-09-26 16:02:48 +03:00
2022-10-24 03:17:50 +03:00
# Copyright 2022 Sygil-Dev team.
2022-09-26 16:02:48 +03:00
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
# You should have received a copy of the GNU Affero General Public License
2022-10-06 11:46:41 +03:00
# along with this program. If not, see <http://www.gnu.org/licenses/>.
2022-09-14 14:19:24 +03:00
# base webui import and utils.
2022-09-14 00:08:40 +03:00
from sd_utils import *
2022-09-14 14:19:24 +03:00
# streamlit imports
2022-09-14 00:08:40 +03:00
from streamlit import StopException
2022-09-14 14:19:24 +03:00
#other imports
import cv2
2022-09-14 00:08:40 +03:00
from PIL import Image , ImageOps
import torch
import k_diffusion as K
import numpy as np
import time
import torch
import skimage
from ldm . models . diffusion . ddim import DDIMSampler
from ldm . models . diffusion . plms import PLMSSampler
2022-10-19 17:52:08 +03:00
# streamlit components
from custom_components import key_phrase_suggestions
2022-10-06 11:46:41 +03:00
# Temp imports
2022-09-14 14:19:24 +03:00
# end of imports
#---------------------------------------------------------------------------------------------------------------
2022-10-19 17:52:08 +03:00
# Initialise the key-phrase suggestion component once, at import time
# (what init() sets up is defined in custom_components).
key_phrase_suggestions.init()

try:
    # this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start.
    # NOTE(review): this binds the module-level name `logging` to transformers' logging
    # module, shadowing the stdlib `logging` name inside this file.
    from transformers import logging
    logging.set_verbosity_error()
except Exception:
    # Best effort only: if transformers is missing or its logging API differs we keep the
    # default verbosity. Unlike a bare `except:`, this does not swallow KeyboardInterrupt/SystemExit.
    pass
2022-10-06 11:46:41 +03:00
# Maps the UI name of each k-diffusion sampler to the name KDiffusionSampler expects.
_K_DIFFUSION_SAMPLERS = {
    'k_dpm_2_a': 'dpm_2_ancestral',
    'k_dpm_2': 'dpm_2',
    'k_euler_a': 'euler_ancestral',
    'k_euler': 'euler',
    'k_heun': 'heun',
    'k_lms': 'lms',
}


def img2img(prompt: str = '', init_info: any = None, init_info_mask: any = None, mask_mode: int = 0, mask_blur_strength: int = 3,
            mask_restore: bool = False, ddim_steps: int = 50, sampler_name: str = 'DDIM',
            n_iter: int = 1, cfg_scale: float = 7.5, denoising_strength: float = 0.8,
            seed: int = -1, noise_mode: int = 0, find_noise_steps: str = "", height: int = 512, width: int = 512, resize_mode: int = 0, fp=None,
            variant_amount: float = 0.0, variant_seed: int = None, ddim_eta: float = 0.0,
            write_info_files: bool = True, separate_prompts: bool = False, normalize_prompt_weights: bool = True,
            save_individual_images: bool = True, save_grid: bool = True, group_by_prompt: bool = True,
            save_as_jpg: bool = True, use_GFPGAN: bool = True, GFPGAN_model: str = 'GFPGANv1.4',
            use_RealESRGAN: bool = True, RealESRGAN_model: str = "RealESRGAN_x4plus_anime_6B",
            use_LDSR: bool = True, LDSR_model: str = "model",
            loopback: bool = False,
            random_seed_loopback: bool = False
            ):
    """Generate images from an init image (img2img), optionally masked and/or looped back.

    Args:
        prompt: Text prompt forwarded to ``process_images``.
        init_info: Init image (PIL image) used as the starting point for sampling.
        init_info_mask: Mask image used when ``mask_mode`` is 0 or 1.
        mask_mode: 0 = use ``init_info_mask`` as is, 1 = use it inverted,
            2 = derive the mask from the alpha channel of ``init_info``.
        mask_blur_strength, mask_restore: Forwarded to ``process_images``. On loopback
            with colour correction they are also used to composite the input image back
            over each result.
        ddim_steps: Total number of sampling steps.
        sampler_name: 'PLMS', 'DDIM' or one of the ``k_*`` keys of ``_K_DIFFUSION_SAMPLERS``.
        denoising_strength: In [0, 1]; ``int(denoising_strength * ddim_steps)`` steps are
            run on top of the encoded init image. With loopback it decays by 5% per
            iteration (floor 0.1), and the decayed value is what the sampler uses.
        noise_mode: 2 or 3 (together with a mask) pre-fills the masked area of the init
            image with matched noise before sampling.
        ddim_eta: eta used for the DDIM schedule.
        loopback: Feed each result back in as the next init image, for ``n_iter`` iterations.
        random_seed_loopback: With loopback, use a fresh random seed per iteration instead
            of ``seed + 1``.
        The remaining arguments are forwarded to ``process_images``; ``fp``,
        ``variant_amount`` and ``variant_seed`` are currently unused.

    Returns:
        ``(output_images, seed, info, stats)`` as produced by ``process_images``. With
        loopback, ``output_images`` holds one image per iteration and ``seed`` is the seed
        of the first iteration.

    Raises:
        Exception: if ``sampler_name`` is not a known sampler.
        AssertionError: if ``denoising_strength`` is outside [0, 1].
    """
    outpath = st.session_state['defaults'].general.outdir_img2img
    seed = seed_to_int(seed)

    batch_size = 1

    if sampler_name == 'PLMS':
        sampler = PLMSSampler(server_state["model"])
    elif sampler_name == 'DDIM':
        sampler = DDIMSampler(server_state["model"])
    elif sampler_name in _K_DIFFUSION_SAMPLERS:
        sampler = KDiffusionSampler(server_state["model"], _K_DIFFUSION_SAMPLERS[sampler_name])
    else:
        raise Exception("Unknown sampler: " + sampler_name)

    def process_init_mask(init_mask: Image):
        """Flatten an RGBA mask onto a black background as RGB; other modes are returned unchanged."""
        if init_mask.mode == "RGBA":
            background = Image.new('RGBA', init_mask.size, (0, 0, 0))
            init_mask = Image.alpha_composite(background, init_mask).convert('RGB')
        return init_mask

    init_img = init_info
    init_mask = None
    if mask_mode == 0:
        if init_info_mask:
            init_mask = process_init_mask(init_info_mask)
    elif mask_mode == 1:
        if init_info_mask:
            init_mask = ImageOps.invert(process_init_mask(init_info_mask))
    elif mask_mode == 2:
        # Use the init image's own alpha channel as the mask. Converting to RGBA first
        # guarantees the last band really is alpha: an image without transparency
        # (e.g. plain RGB) gets a fully opaque mask instead of having its blue channel
        # misread as alpha.
        init_mask = init_img.convert('RGBA').split()[-1].convert("RGB")
        init_mask = resize_image(resize_mode, init_mask, width, height)
        init_mask = init_mask.convert("RGB")

    assert 0. <= denoising_strength <= 1., 'can only work with strength in [0.0, 1.0]'

    if init_mask is not None and (noise_mode == 2 or noise_mode == 3) and init_img is not None:
        noise_q = 0.99
        color_variation = 0.0
        mask_blend_factor = 1.0

        np_init = (np.asarray(init_img.convert("RGB")) / 255.0).astype(np.float64)
        # Normalise the mask to [0, 1], then flip it: afterwards values near 1 mark pixels
        # replaced by matched noise and values near 0 form the histogram reference region.
        np_mask_rgb = 1. - (np.asarray(ImageOps.invert(init_mask).convert("RGB")) / 255.0).astype(np.float64)
        np_mask_rgb -= np.min(np_mask_rgb)
        np_mask_rgb /= np.max(np_mask_rgb)
        np_mask_rgb = 1. - np_mask_rgb
        np_mask_rgb_hardened = 1. - (np_mask_rgb < 0.99).astype(np.float64)
        # The noise mask and the blend mask are the same blurred dilation, so the (very
        # wide, costly) Gaussian is computed only once.
        # TODO: derive the 0.7071 scale factor; it is currently a magic constant.
        blurred = skimage.filters.gaussian(np_mask_rgb_hardened, sigma=16., channel_axis=2, truncate=32.)
        np_mask_rgb_dilated = np.clip((np_mask_rgb + blurred) * 0.7071, 0., 1.)
        np_mask_rgb = np_mask_rgb_dilated.copy()

        noise_rgb = get_matched_noise(np_init, np_mask_rgb, noise_q, color_variation)
        blend_mask_rgb = np.clip(np_mask_rgb_dilated, 0., 1.) ** mask_blend_factor
        blend_mask_rgb **= 2.
        noised = np_init * (1. - blend_mask_rgb) + noise_rgb * blend_mask_rgb

        # Match the histogram of the whole image to the unmasked (reference) pixels.
        np_mask_grey = np.sum(np_mask_rgb, axis=2) / 3.
        ref_mask = np_mask_grey < 1e-3
        all_mask = np.ones((height, width), dtype=bool)
        noised[all_mask, :] = skimage.exposure.match_histograms(noised[all_mask, :], noised[ref_mask, :], channel_axis=1)

        init_img = Image.fromarray(np.clip(noised * 255., 0., 255.).astype(np.uint8), mode="RGB")
        st.session_state["editor_image"].image(init_img)  # debug: show the pre-noised init image

    def init():
        """Encode ``init_img`` (and ``init_mask``, if any) for the sampler.

        Returns ``(init_latent, mask)``. ``mask`` is None without a mask; otherwise it is
        ``1 - mask/255`` (white mask pixels become 0), tiled to 4 channels with a leading
        batch axis and moved to the model device. Presumably (1, 4, height//8, width//8)
        if ``resize_image`` returns the requested size; TODO confirm.
        """
        image = np.array(init_img.convert('RGB')).astype(np.float32) / 255.0
        image = torch.from_numpy(image[None].transpose(0, 3, 1, 2))

        mask = None
        if init_mask is not None:
            alpha = resize_image(resize_mode, init_mask, width // 8, height // 8)
            mask_channel = np.array(alpha.split()[-1]).astype(np.float32) / 255.0
            mask = np.tile(1 - mask_channel, (4, 1, 1))[None]
            mask = torch.from_numpy(mask).to(server_state["device"])

        optimized = st.session_state['defaults'].general.optimized
        if optimized:
            server_state["modelFS"].to(server_state["device"])

        init_image = (2. * image - 1.).to(server_state["device"])
        first_stage_model = server_state["modelFS"] if optimized else server_state["model"]
        init_latent = first_stage_model.get_first_stage_encoding(first_stage_model.encode_first_stage(init_image))  # move to latent space

        if optimized:
            # Offload the first-stage model and wait until CUDA reports that memory was freed.
            # NOTE(review): this waits forever if allocated memory never drops below `mem`.
            mem = torch.cuda.memory_allocated() / 1e6
            server_state["modelFS"].to("cpu")
            while torch.cuda.memory_allocated() / 1e6 >= mem:
                time.sleep(1)

        return init_latent, mask

    def sample(init_data, x, conditioning, unconditional_conditioning, sampler_name):
        """Denoise ``init_data == (x0, z_mask)`` starting from noise ``x``; return the sampled latents."""
        # Read denoising_strength at call time: loopback decays it between iterations, and
        # a step count computed once up front would silently ignore that decay.
        t_enc_steps = int(denoising_strength * ddim_steps)
        obliterate = False
        if ddim_steps == t_enc_steps:
            t_enc_steps = t_enc_steps - 1
            obliterate = True

        x0, z_mask = init_data

        if sampler_name != 'DDIM':
            sigmas = sampler.model_wrap.get_sigmas(ddim_steps)
            noise = x * sigmas[ddim_steps - t_enc_steps - 1]
            xi = x0 + noise

            # Obliterate masked image
            if z_mask is not None and obliterate:
                # This draw is not used (the sigma-scaled `noise` is), but it advances the
                # RNG, so it is kept to leave seeded results unchanged.
                torch.randn(z_mask.shape, device=xi.device)
                xi = (z_mask * noise) + ((1 - z_mask) * xi)

            sigma_sched = sigmas[ddim_steps - t_enc_steps - 1:]
            model_wrap_cfg = CFGMaskedDenoiser(sampler.model_wrap)
            sample_fn = K.sampling.__dict__[f'sample_{sampler.get_sampler_name()}']
            samples_ddim = sample_fn(model_wrap_cfg, xi, sigma_sched,
                                     extra_args={'cond': conditioning, 'uncond': unconditional_conditioning,
                                                 'cond_scale': cfg_scale, 'mask': z_mask, 'x0': x0, 'xi': xi},
                                     disable=False,
                                     callback=generation_callback if not server_state["bridge"] else None)
        else:
            sampler.make_schedule(ddim_num_steps=ddim_steps, ddim_eta=ddim_eta, verbose=False)
            z_enc = sampler.stochastic_encode(x0, torch.tensor([t_enc_steps] * batch_size).to(server_state["device"]))

            # Obliterate masked image
            if z_mask is not None and obliterate:
                random = torch.randn(z_mask.shape, device=z_enc.device)
                z_enc = (z_mask * random) + ((1 - z_mask) * z_enc)

            # decode it
            samples_ddim = sampler.decode(z_enc, conditioning, t_enc_steps,
                                          unconditional_guidance_scale=cfg_scale,
                                          unconditional_conditioning=unconditional_conditioning,
                                          z_mask=z_mask, x0=x0)
        return samples_ddim

    if loopback:
        output_images, info = None, None
        history = []
        initial_seed = None

        do_color_correction = False
        try:
            from skimage import exposure
            do_color_correction = True
        except ImportError:
            logger.error("Install scikit-image to perform color correction on loopback")

        # LAB colour target taken from the starting image; every iteration is matched back to it.
        correction_target = None
        if do_color_correction:
            correction_target = cv2.cvtColor(np.asarray(init_img.convert("RGB")), cv2.COLOR_RGB2LAB)

        for i in range(n_iter):
            # Upscaling is forcefully disabled on all but the final loopback iteration;
            # LDSR is gated the same way as RealESRGAN.
            is_final_iteration = i == n_iter - 1

            output_images, seed, info, stats = process_images(
                outpath=outpath,
                func_init=init,
                func_sample=sample,
                prompt=prompt,
                seed=seed,
                sampler_name=sampler_name,
                save_grid=save_grid,
                batch_size=1,
                n_iter=1,
                steps=ddim_steps,
                cfg_scale=cfg_scale,
                width=width,
                height=height,
                prompt_matrix=separate_prompts,
                use_GFPGAN=use_GFPGAN,
                GFPGAN_model=GFPGAN_model,
                use_RealESRGAN=use_RealESRGAN and is_final_iteration,
                realesrgan_model_name=RealESRGAN_model,
                use_LDSR=use_LDSR and is_final_iteration,
                LDSR_model_name=LDSR_model,
                normalize_prompt_weights=normalize_prompt_weights,
                save_individual_images=save_individual_images,
                init_img=init_img,
                init_mask=init_mask,
                mask_blur_strength=mask_blur_strength,
                mask_restore=mask_restore,
                denoising_strength=denoising_strength,
                noise_mode=noise_mode,
                find_noise_steps=find_noise_steps,
                resize_mode=resize_mode,
                uses_loopback=loopback,
                uses_random_seed_loopback=random_seed_loopback,
                sort_samples=group_by_prompt,
                write_info_files=write_info_files,
                jpg_sample=save_as_jpg
            )

            if initial_seed is None:
                initial_seed = seed

            input_image = init_img
            init_img = output_images[0]

            if correction_target is not None:
                init_img = Image.fromarray(cv2.cvtColor(exposure.match_histograms(
                    cv2.cvtColor(np.asarray(init_img.convert("RGB")), cv2.COLOR_RGB2LAB),
                    correction_target,
                    channel_axis=2
                ), cv2.COLOR_LAB2RGB).astype("uint8"))
                # Colour correction shifts the whole image; put the unmasked input back.
                if mask_restore is True and init_mask is not None:
                    color_mask = init_mask.filter(ImageFilter.GaussianBlur(mask_blur_strength)).convert('L')
                    source_image = input_image.convert('RGB')
                    target_image = init_img.convert('RGB')
                    init_img = Image.composite(source_image, target_image, color_mask)

            if not random_seed_loopback:
                seed = seed + 1
            else:
                seed = seed_to_int(None)

            # Read by `sample` on the next iteration, so the decay affects sampling.
            denoising_strength = max(denoising_strength * 0.95, 0.1)
            history.append(init_img)

        output_images = history
        seed = initial_seed
    else:
        output_images, seed, info, stats = process_images(
            outpath=outpath,
            func_init=init,
            func_sample=sample,
            prompt=prompt,
            seed=seed,
            sampler_name=sampler_name,
            save_grid=save_grid,
            batch_size=batch_size,
            n_iter=n_iter,
            steps=ddim_steps,
            cfg_scale=cfg_scale,
            width=width,
            height=height,
            prompt_matrix=separate_prompts,
            use_GFPGAN=use_GFPGAN,
            GFPGAN_model=GFPGAN_model,
            use_RealESRGAN=use_RealESRGAN,
            realesrgan_model_name=RealESRGAN_model,
            use_LDSR=use_LDSR,
            LDSR_model_name=LDSR_model,
            normalize_prompt_weights=normalize_prompt_weights,
            save_individual_images=save_individual_images,
            init_img=init_img,
            init_mask=init_mask,
            mask_blur_strength=mask_blur_strength,
            denoising_strength=denoising_strength,
            noise_mode=noise_mode,
            find_noise_steps=find_noise_steps,
            mask_restore=mask_restore,
            resize_mode=resize_mode,
            uses_loopback=loopback,
            sort_samples=group_by_prompt,
            write_info_files=write_info_files,
            jpg_sample=save_as_jpg
        )

    del sampler

    return output_images, seed, info, stats
2022-10-06 11:46:41 +03:00
2022-09-14 00:08:40 +03:00
#
def layout ( ) :
with st . form ( " img2img-inputs " ) :
st . session_state [ " generation_mode " ] = " img2img "
2022-10-06 11:46:41 +03:00
2022-09-14 00:08:40 +03:00
img2img_input_col , img2img_generate_col = st . columns ( [ 10 , 1 ] )
with img2img_input_col :
#prompt = st.text_area("Input Text","")
2022-10-19 17:52:08 +03:00
placeholder = " A corgi wearing a top hat as an oil painting. "
2022-10-20 23:19:26 +03:00
prompt = st . text_area ( " Input Text " , " " , placeholder = placeholder , height = 54 )
2022-10-19 17:52:08 +03:00
key_phrase_suggestions . suggestion_area ( placeholder )
2022-10-06 11:46:41 +03:00
2022-09-14 00:08:40 +03:00
# Every form must have a submit button, the extra blank spaces is a temp way to align it with the input field. Needs to be done in CSS or some other way.
img2img_generate_col . write ( " " )
img2img_generate_col . write ( " " )
generate_button = img2img_generate_col . form_submit_button ( " Generate " )
2022-10-06 11:46:41 +03:00
2022-09-14 00:08:40 +03:00
# creating the page layout using columns
2022-10-20 17:48:39 +03:00
col1_img2img_layout , col2_img2img_layout , col3_img2img_layout = st . columns ( [ 1 , 2 , 2 ] , gap = " medium " )
2022-10-06 11:46:41 +03:00
2022-09-14 00:08:40 +03:00
with col1_img2img_layout :
2022-10-06 11:46:41 +03:00
# If we have custom models available on the "models/custom"
2022-09-14 00:08:40 +03:00
#folder then we show a menu to select which model we want to use, otherwise we use the main model for SD
2022-09-25 01:27:09 +03:00
custom_models_available ( )
2022-09-29 13:29:44 +03:00
if server_state [ " CustomModel_available " ] :
2022-09-25 09:23:55 +03:00
st . session_state [ " custom_model " ] = st . selectbox ( " Custom Model: " , server_state [ " custom_models " ] ,
index = server_state [ " custom_models " ] . index ( st . session_state [ ' defaults ' ] . general . default_model ) ,
2022-09-14 00:08:40 +03:00
help = " Select the model you want to use. This option is only available if you have custom models \
on your ' models/custom ' folder . The model name that will be shown here is the same as the name \
the file for the model has on said folder , it is recommended to give the . ckpt file a name that \
2022-10-21 05:50:40 +03:00
will make it easier for you to distinguish it from other models . Default : Stable Diffusion v1 .5 " )
2022-09-14 00:08:40 +03:00
else :
2022-10-21 05:50:40 +03:00
st . session_state [ " custom_model " ] = " Stable Diffusion v1.5 "
2022-10-06 11:46:41 +03:00
st . session_state [ " sampling_steps " ] = st . number_input ( " Sampling Steps " , value = st . session_state [ ' defaults ' ] . img2img . sampling_steps . value ,
min_value = st . session_state [ ' defaults ' ] . img2img . sampling_steps . min_value ,
step = st . session_state [ ' defaults ' ] . img2img . sampling_steps . step )
2022-09-14 00:08:40 +03:00
sampler_name_list = [ " k_lms " , " k_euler " , " k_euler_a " , " k_dpm_2 " , " k_dpm_2_a " , " k_heun " , " PLMS " , " DDIM " ]
2022-10-06 11:46:41 +03:00
st . session_state [ " sampler_name " ] = st . selectbox ( " Sampling method " , sampler_name_list ,
index = sampler_name_list . index ( st . session_state [ ' defaults ' ] . img2img . sampler_name ) , help = " Sampling method to use. " )
2022-09-18 21:11:23 +03:00
width = st . slider ( " Width: " , min_value = st . session_state [ ' defaults ' ] . img2img . width . min_value , max_value = st . session_state [ ' defaults ' ] . img2img . width . max_value ,
2022-09-26 14:53:34 +03:00
value = st . session_state [ ' defaults ' ] . img2img . width . value , step = st . session_state [ ' defaults ' ] . img2img . width . step )
2022-09-18 21:11:23 +03:00
height = st . slider ( " Height: " , min_value = st . session_state [ ' defaults ' ] . img2img . height . min_value , max_value = st . session_state [ ' defaults ' ] . img2img . height . max_value ,
2022-09-26 14:53:34 +03:00
value = st . session_state [ ' defaults ' ] . img2img . height . value , step = st . session_state [ ' defaults ' ] . img2img . height . step )
2022-10-06 11:46:41 +03:00
seed = st . text_input ( " Seed: " , value = st . session_state [ ' defaults ' ] . img2img . seed , help = " The seed to use, if left blank a random seed will be generated. " )
2022-10-18 04:01:19 +03:00
cfg_scale = st . number_input ( " CFG (Classifier Free Guidance Scale): " , min_value = st . session_state [ ' defaults ' ] . img2img . cfg_scale . min_value ,
2022-10-19 16:49:03 +03:00
value = st . session_state [ ' defaults ' ] . img2img . cfg_scale . value ,
2022-10-18 04:01:19 +03:00
step = st . session_state [ ' defaults ' ] . img2img . cfg_scale . step ,
help = " How strongly the image should follow the prompt. " )
2022-10-06 11:46:41 +03:00
st . session_state [ " denoising_strength " ] = st . slider ( " Denoising Strength: " , value = st . session_state [ ' defaults ' ] . img2img . denoising_strength . value ,
2022-10-18 04:01:19 +03:00
min_value = st . session_state [ ' defaults ' ] . img2img . denoising_strength . min_value ,
max_value = st . session_state [ ' defaults ' ] . img2img . denoising_strength . max_value ,
step = st . session_state [ ' defaults ' ] . img2img . denoising_strength . step )
2022-10-06 11:46:41 +03:00
2022-09-26 14:53:34 +03:00
mask_expander = st . empty ( )
with mask_expander . expander ( " Mask " ) :
mask_mode_list = [ " Mask " , " Inverted mask " , " Image alpha " ]
mask_mode = st . selectbox ( " Mask Mode " , mask_mode_list ,
2022-10-18 04:01:19 +03:00
help = " Select how you want your image to be masked. \" Mask \" modifies the image where the mask is white. \n \
\" Inverted mask \" modifies the image where the mask is black. \" Image alpha \" modifies the image where the image is transparent. "
)
2022-09-26 14:53:34 +03:00
mask_mode = mask_mode_list . index ( mask_mode )
2022-10-06 11:46:41 +03:00
2022-09-26 14:53:34 +03:00
noise_mode_list = [ " Seed " , " Find Noise " , " Matched Noise " , " Find+Matched Noise " ]
noise_mode = st . selectbox (
" Noise Mode " , noise_mode_list ,
help = " "
)
noise_mode = noise_mode_list . index ( noise_mode )
2022-10-18 04:01:19 +03:00
find_noise_steps = st . number_input ( " Find Noise Steps " , value = st . session_state [ ' defaults ' ] . img2img . find_noise_steps . value ,
min_value = st . session_state [ ' defaults ' ] . img2img . find_noise_steps . min_value ,
2022-09-26 14:53:34 +03:00
step = st . session_state [ ' defaults ' ] . img2img . find_noise_steps . step )
2022-10-06 11:46:41 +03:00
2022-09-26 14:53:34 +03:00
with st . expander ( " Batch Options " ) :
2022-10-08 03:57:06 +03:00
st . session_state [ " batch_count " ] = st . number_input ( " Batch count. " , value = st . session_state [ ' defaults ' ] . img2img . batch_count . value ,
2022-10-18 04:01:19 +03:00
help = " How many iterations or batches of images to generate in total. " )
2022-10-02 23:10:17 +03:00
2022-10-08 03:57:06 +03:00
st . session_state [ " batch_size " ] = st . number_input ( " Batch size " , value = st . session_state . defaults . img2img . batch_size . value ,
2022-10-18 04:01:19 +03:00
help = " How many images are at once in a batch. \
It increases the VRAM usage a lot but if you have enough VRAM it can reduce the time it takes to finish generation as more images are generated at once . \
Default : 1 " )
2022-10-06 11:46:41 +03:00
2022-09-26 14:53:34 +03:00
with st . expander ( " Preview Settings " ) :
2022-10-02 23:10:17 +03:00
st . session_state [ " update_preview " ] = st . session_state [ " defaults " ] . general . update_preview
2022-10-08 02:59:29 +03:00
st . session_state [ " update_preview_frequency " ] = st . number_input ( " Update Image Preview Frequency " ,
2022-10-18 04:01:19 +03:00
min_value = 1 ,
value = st . session_state [ ' defaults ' ] . img2img . update_preview_frequency ,
help = " Frequency in steps at which the the preview image is updated. By default the frequency \
is set to 1 step . " )
2022-10-06 11:46:41 +03:00
#
2022-09-14 00:08:40 +03:00
with st . expander ( " Advanced " ) :
2022-10-02 23:10:17 +03:00
with st . expander ( " Output Settings " ) :
separate_prompts = st . checkbox ( " Create Prompt Matrix. " , value = st . session_state [ ' defaults ' ] . img2img . separate_prompts ,
help = " Separate multiple prompts using the `|` character, and get all combinations of them. " )
normalize_prompt_weights = st . checkbox ( " Normalize Prompt Weights. " , value = st . session_state [ ' defaults ' ] . img2img . normalize_prompt_weights ,
help = " Ensure the sum of all weights add up to 1.0 " )
loopback = st . checkbox ( " Loopback. " , value = st . session_state [ ' defaults ' ] . img2img . loopback , help = " Use images from previous batch when creating next batch. " )
random_seed_loopback = st . checkbox ( " Random loopback seed. " , value = st . session_state [ ' defaults ' ] . img2img . random_seed_loopback , help = " Random loopback seed " )
img2img_mask_restore = st . checkbox ( " Only modify regenerated parts of image " ,
value = st . session_state [ ' defaults ' ] . img2img . mask_restore ,
help = " Enable to restore the unmasked parts of the image with the input, may not blend as well but preserves detail " )
save_individual_images = st . checkbox ( " Save individual images. " , value = st . session_state [ ' defaults ' ] . img2img . save_individual_images ,
help = " Save each image generated before any filter or enhancement is applied. " )
save_grid = st . checkbox ( " Save grid " , value = st . session_state [ ' defaults ' ] . img2img . save_grid , help = " Save a grid with all the images generated into a single image. " )
group_by_prompt = st . checkbox ( " Group results by prompt " , value = st . session_state [ ' defaults ' ] . img2img . group_by_prompt ,
help = " Saves all the images with the same prompt into the same folder. \
When using a prompt matrix each prompt combination will have its own folder . " )
2022-10-06 11:46:41 +03:00
write_info_files = st . checkbox ( " Write Info file " , value = st . session_state [ ' defaults ' ] . img2img . write_info_files ,
help = " Save a file next to the image with informartion about the generation. " )
2022-10-02 23:10:17 +03:00
save_as_jpg = st . checkbox ( " Save samples as jpg " , value = st . session_state [ ' defaults ' ] . img2img . save_as_jpg , help = " Saves the images as jpg instead of png. " )
2022-10-06 11:46:41 +03:00
2022-10-02 06:18:09 +03:00
#
2022-10-02 01:40:50 +03:00
# check if GFPGAN, RealESRGAN and LDSR are available.
if " GFPGAN_available " not in st . session_state :
GFPGAN_available ( )
2022-10-06 11:46:41 +03:00
2022-10-02 01:40:50 +03:00
if " RealESRGAN_available " not in st . session_state :
RealESRGAN_available ( )
2022-10-06 11:46:41 +03:00
2022-10-02 06:18:09 +03:00
if " LDSR_available " not in st . session_state :
LDSR_available ( )
2022-10-06 11:46:41 +03:00
2022-10-02 06:18:09 +03:00
if st . session_state [ " GFPGAN_available " ] or st . session_state [ " RealESRGAN_available " ] or st . session_state [ " LDSR_available " ] :
2022-10-02 01:40:50 +03:00
with st . expander ( " Post-Processing " ) :
2022-10-02 06:18:09 +03:00
face_restoration_tab , upscaling_tab = st . tabs ( [ " Face Restoration " , " Upscaling " ] )
with face_restoration_tab :
# GFPGAN used for face restoration
if st . session_state [ " GFPGAN_available " ] :
#with st.expander("Face Restoration"):
#if st.session_state["GFPGAN_available"]:
#with st.expander("GFPGAN"):
2022-10-03 06:26:01 +03:00
st . session_state [ " use_GFPGAN " ] = st . checkbox ( " Use GFPGAN " , value = st . session_state [ ' defaults ' ] . img2img . use_GFPGAN ,
2022-10-02 06:18:09 +03:00
help = " Uses the GFPGAN model to improve faces after the generation. \
This greatly improve the quality and consistency of faces but uses \
extra VRAM . Disable if you need the extra VRAM . " )
2022-10-06 11:46:41 +03:00
2022-10-02 01:40:50 +03:00
st . session_state [ " GFPGAN_model " ] = st . selectbox ( " GFPGAN model " , st . session_state [ " GFPGAN_models " ] ,
2022-10-06 11:46:41 +03:00
index = st . session_state [ " GFPGAN_models " ] . index ( st . session_state [ ' defaults ' ] . general . GFPGAN_model ) )
2022-10-02 01:40:50 +03:00
#st.session_state["GFPGAN_strenght"] = st.slider("Effect Strenght", min_value=1, max_value=100, value=1, step=1, help='')
2022-10-06 11:46:41 +03:00
2022-10-02 01:40:50 +03:00
else :
2022-10-06 11:46:41 +03:00
st . session_state [ " use_GFPGAN " ] = False
2022-10-02 06:18:09 +03:00
with upscaling_tab :
2022-10-03 06:26:01 +03:00
st . session_state [ ' us_upscaling ' ] = st . checkbox ( " Use Upscaling " , value = st . session_state [ ' defaults ' ] . img2img . use_upscaling )
2022-10-06 11:46:41 +03:00
# RealESRGAN and LDSR used for upscaling.
2022-10-02 06:18:09 +03:00
if st . session_state [ " RealESRGAN_available " ] or st . session_state [ " LDSR_available " ] :
2022-10-06 11:46:41 +03:00
2022-10-02 06:18:09 +03:00
upscaling_method_list = [ ]
if st . session_state [ " RealESRGAN_available " ] :
upscaling_method_list . append ( " RealESRGAN " )
if st . session_state [ " LDSR_available " ] :
upscaling_method_list . append ( " LDSR " )
2022-10-06 11:46:41 +03:00
2022-10-02 06:18:09 +03:00
st . session_state [ " upscaling_method " ] = st . selectbox ( " Upscaling Method " , upscaling_method_list ,
index = upscaling_method_list . index ( st . session_state [ ' defaults ' ] . general . upscaling_method ) )
2022-10-06 11:46:41 +03:00
2022-10-02 06:18:09 +03:00
if st . session_state [ " RealESRGAN_available " ] :
2022-10-03 06:26:01 +03:00
with st . expander ( " RealESRGAN " ) :
if st . session_state [ " upscaling_method " ] == " RealESRGAN " and st . session_state [ ' us_upscaling ' ] :
st . session_state [ " use_RealESRGAN " ] = True
else :
st . session_state [ " use_RealESRGAN " ] = False
2022-10-06 11:46:41 +03:00
2022-10-03 06:26:01 +03:00
st . session_state [ " RealESRGAN_model " ] = st . selectbox ( " RealESRGAN model " , st . session_state [ " RealESRGAN_models " ] ,
2022-10-06 11:46:41 +03:00
index = st . session_state [ " RealESRGAN_models " ] . index ( st . session_state [ ' defaults ' ] . general . RealESRGAN_model ) )
2022-10-02 06:18:09 +03:00
else :
st . session_state [ " use_RealESRGAN " ] = False
st . session_state [ " RealESRGAN_model " ] = " RealESRGAN_x4plus "
2022-10-06 11:46:41 +03:00
2022-10-02 06:18:09 +03:00
#
if st . session_state [ " LDSR_available " ] :
2022-10-03 06:26:01 +03:00
with st . expander ( " LDSR " ) :
if st . session_state [ " upscaling_method " ] == " LDSR " and st . session_state [ ' us_upscaling ' ] :
st . session_state [ " use_LDSR " ] = True
else :
st . session_state [ " use_LDSR " ] = False
2022-10-06 11:46:41 +03:00
2022-10-03 06:26:01 +03:00
st . session_state [ " LDSR_model " ] = st . selectbox ( " LDSR model " , st . session_state [ " LDSR_models " ] ,
2022-10-06 11:46:41 +03:00
index = st . session_state [ " LDSR_models " ] . index ( st . session_state [ ' defaults ' ] . general . LDSR_model ) )
2022-10-08 03:57:06 +03:00
st . session_state [ " ldsr_sampling_steps " ] = st . number_input ( " Sampling Steps " , value = st . session_state [ ' defaults ' ] . img2img . LDSR_config . sampling_steps ,
help = " " )
2022-10-06 11:46:41 +03:00
2022-10-08 03:57:06 +03:00
st . session_state [ " preDownScale " ] = st . number_input ( " PreDownScale " , value = st . session_state [ ' defaults ' ] . img2img . LDSR_config . preDownScale ,
help = " " )
2022-10-06 11:46:41 +03:00
2022-10-08 03:57:06 +03:00
st . session_state [ " postDownScale " ] = st . number_input ( " postDownScale " , value = st . session_state [ ' defaults ' ] . img2img . LDSR_config . postDownScale ,
help = " " )
2022-10-06 11:46:41 +03:00
2022-10-03 06:26:01 +03:00
downsample_method_list = [ ' Nearest ' , ' Lanczos ' ]
st . session_state [ " downsample_method " ] = st . selectbox ( " Downsample Method " , downsample_method_list ,
index = downsample_method_list . index ( st . session_state [ ' defaults ' ] . img2img . LDSR_config . downsample_method ) )
2022-10-06 11:46:41 +03:00
2022-10-02 06:18:09 +03:00
else :
st . session_state [ " use_LDSR " ] = False
2022-10-06 11:46:41 +03:00
st . session_state [ " LDSR_model " ] = " model "
2022-09-26 14:53:34 +03:00
with st . expander ( " Variant " ) :
variant_amount = st . slider ( " Variant Amount: " , value = st . session_state [ ' defaults ' ] . img2img . variant_amount , min_value = 0.0 , max_value = 1.0 , step = 0.01 )
variant_seed = st . text_input ( " Variant Seed: " , value = st . session_state [ ' defaults ' ] . img2img . variant_seed ,
help = " The seed to use when generating a variant, if left blank a random seed will be generated. " )
2022-10-06 11:46:41 +03:00
2022-09-14 00:08:40 +03:00
with col2_img2img_layout :
editor_tab = st . tabs ( [ " Editor " ] )
2022-10-06 11:46:41 +03:00
2022-09-14 00:08:40 +03:00
editor_image = st . empty ( )
st . session_state [ " editor_image " ] = editor_image
2022-10-06 11:46:41 +03:00
2022-09-14 00:08:40 +03:00
masked_image_holder = st . empty ( )
image_holder = st . empty ( )
2022-10-06 11:46:41 +03:00
2022-10-03 06:26:01 +03:00
st . form_submit_button ( " Refresh " )
2022-10-06 11:46:41 +03:00
2022-09-14 00:08:40 +03:00
uploaded_images = st . file_uploader (
2022-09-16 22:55:58 +03:00
" Upload Image " , accept_multiple_files = False , type = [ " png " , " jpg " , " jpeg " , " webp " ] ,
2022-09-14 00:08:40 +03:00
help = " Upload an image which will be used for the image to image generation. " ,
)
if uploaded_images :
image = Image . open ( uploaded_images ) . convert ( ' RGBA ' )
new_img = image . resize ( ( width , height ) )
image_holder . image ( new_img )
2022-10-06 11:46:41 +03:00
2022-09-14 00:08:40 +03:00
mask_holder = st . empty ( )
2022-10-06 11:46:41 +03:00
2022-09-14 00:08:40 +03:00
uploaded_masks = st . file_uploader (
2022-09-16 22:55:58 +03:00
" Upload Mask " , accept_multiple_files = False , type = [ " png " , " jpg " , " jpeg " , " webp " ] ,
2022-09-14 00:08:40 +03:00
help = " Upload an mask image which will be used for masking the image to image generation. " ,
)
if uploaded_masks :
2022-09-26 14:53:34 +03:00
mask_expander . expander ( " Mask " , expanded = True )
2022-09-14 00:08:40 +03:00
mask = Image . open ( uploaded_masks )
if mask . mode == " RGBA " :
mask = mask . convert ( ' RGBA ' )
background = Image . new ( ' RGBA ' , mask . size , ( 0 , 0 , 0 ) )
mask = Image . alpha_composite ( background , mask )
mask = mask . resize ( ( width , height ) )
mask_holder . image ( mask )
2022-10-06 11:46:41 +03:00
2022-09-14 00:08:40 +03:00
if uploaded_images and uploaded_masks :
if mask_mode != 2 :
final_img = new_img . copy ( )
alpha_layer = mask . convert ( ' L ' )
strength = st . session_state [ " denoising_strength " ]
if mask_mode == 0 :
alpha_layer = ImageOps . invert ( alpha_layer )
alpha_layer = alpha_layer . point ( lambda a : a * strength )
alpha_layer = ImageOps . invert ( alpha_layer )
elif mask_mode == 1 :
alpha_layer = alpha_layer . point ( lambda a : a * strength )
alpha_layer = ImageOps . invert ( alpha_layer )
2022-10-06 11:46:41 +03:00
2022-09-14 00:08:40 +03:00
final_img . putalpha ( alpha_layer )
2022-10-06 11:46:41 +03:00
2022-09-14 00:08:40 +03:00
with masked_image_holder . container ( ) :
st . text ( " Masked Image Preview " )
st . image ( final_img )
2022-10-06 11:46:41 +03:00
2022-09-14 00:08:40 +03:00
with col3_img2img_layout :
result_tab = st . tabs ( [ " Result " ] )
2022-10-06 11:46:41 +03:00
2022-09-14 00:08:40 +03:00
# create an empty container for the image, progress bar, etc so we can update it later and use session_state to hold them globally.
preview_image = st . empty ( )
st . session_state [ " preview_image " ] = preview_image
2022-10-06 11:46:41 +03:00
2022-09-14 00:08:40 +03:00
#st.session_state["loading"] = st.empty()
2022-10-06 11:46:41 +03:00
2022-09-14 00:08:40 +03:00
st . session_state [ " progress_bar_text " ] = st . empty ( )
st . session_state [ " progress_bar " ] = st . empty ( )
2022-10-06 11:46:41 +03:00
2022-09-14 00:08:40 +03:00
message = st . empty ( )
2022-10-06 11:46:41 +03:00
2022-09-14 00:08:40 +03:00
#if uploaded_images:
#image = Image.open(uploaded_images).convert('RGB')
##img_array = np.array(image) # if you want to pass it to OpenCV
#new_img = image.resize((width, height))
#st.image(new_img, use_column_width=True)
2022-10-06 11:46:41 +03:00
2022-09-14 00:08:40 +03:00
if generate_button :
#print("Loading models")
# load the models when we hit the generate button for the first time, it wont be loaded after that so dont worry.
2022-09-29 13:29:44 +03:00
with col3_img2img_layout :
with hc . HyLoader ( ' Loading Models... ' , hc . Loaders . standard_loaders , index = [ 0 ] ) :
2022-10-03 06:26:01 +03:00
load_models ( use_LDSR = st . session_state [ " use_LDSR " ] , LDSR_model = st . session_state [ " LDSR_model " ] ,
2022-10-06 11:46:41 +03:00
use_GFPGAN = st . session_state [ " use_GFPGAN " ] , GFPGAN_model = st . session_state [ " GFPGAN_model " ] ,
use_RealESRGAN = st . session_state [ " use_RealESRGAN " ] , RealESRGAN_model = st . session_state [ " RealESRGAN_model " ] ,
CustomModel_available = server_state [ " CustomModel_available " ] , custom_model = st . session_state [ " custom_model " ] )
2022-09-14 00:08:40 +03:00
if uploaded_images :
image = Image . open ( uploaded_images ) . convert ( ' RGBA ' )
new_img = image . resize ( ( width , height ) )
#img_array = np.array(image) # if you want to pass it to OpenCV
new_mask = None
if uploaded_masks :
mask = Image . open ( uploaded_masks ) . convert ( ' RGBA ' )
new_mask = mask . resize ( ( width , height ) )
2022-10-06 11:46:41 +03:00
2022-09-14 00:08:40 +03:00
try :
output_images , seed , info , stats = img2img ( prompt = prompt , init_info = new_img , init_info_mask = new_mask , mask_mode = mask_mode ,
2022-09-14 15:48:21 +03:00
mask_restore = img2img_mask_restore , ddim_steps = st . session_state [ " sampling_steps " ] ,
2022-10-02 23:10:17 +03:00
sampler_name = st . session_state [ " sampler_name " ] , n_iter = st . session_state [ " batch_count " ] ,
2022-09-14 00:08:40 +03:00
cfg_scale = cfg_scale , denoising_strength = st . session_state [ " denoising_strength " ] , variant_seed = variant_seed ,
2022-10-06 11:46:41 +03:00
seed = seed , noise_mode = noise_mode , find_noise_steps = find_noise_steps , width = width ,
height = height , variant_amount = variant_amount ,
2022-09-18 21:11:23 +03:00
ddim_eta = st . session_state . defaults . img2img . ddim_eta , write_info_files = write_info_files ,
2022-09-14 00:08:40 +03:00
separate_prompts = separate_prompts , normalize_prompt_weights = normalize_prompt_weights ,
2022-10-06 11:46:41 +03:00
save_individual_images = save_individual_images , save_grid = save_grid ,
2022-10-02 01:40:50 +03:00
group_by_prompt = group_by_prompt , save_as_jpg = save_as_jpg , use_GFPGAN = st . session_state [ " use_GFPGAN " ] ,
2022-10-02 06:18:09 +03:00
GFPGAN_model = st . session_state [ " GFPGAN_model " ] ,
use_RealESRGAN = st . session_state [ " use_RealESRGAN " ] , RealESRGAN_model = st . session_state [ " RealESRGAN_model " ] ,
use_LDSR = st . session_state [ " use_LDSR " ] , LDSR_model = st . session_state [ " LDSR_model " ] ,
loopback = loopback
2022-09-14 00:08:40 +03:00
)
2022-10-06 11:46:41 +03:00
2022-09-14 00:08:40 +03:00
#show a message when the generation is complete.
message . success ( ' Render Complete: ' + info + ' ; Stats: ' + stats , icon = " ✅ " )
2022-10-06 11:46:41 +03:00
2022-09-14 00:08:40 +03:00
except ( StopException , KeyError ) :
2022-10-15 15:34:07 +03:00
logger . info ( f " Received Streamlit StopException " )
2022-10-06 11:46:41 +03:00
2022-09-14 00:08:40 +03:00
# this will render all the images at the end of the generation but its better if its moved to a second tab inside col2 and shown as a gallery.
# use the current col2 first tab to show the preview_img and update it as its generated.
#preview_image.image(output_images, width=750)
#on import run init