2022-09-03 12:08:45 +03:00
import html
import json
2022-09-10 11:10:00 +03:00
import math
2022-09-03 12:08:45 +03:00
import mimetypes
import os
2022-10-22 14:07:00 +03:00
import platform
2022-09-05 23:08:06 +03:00
import random
2022-10-22 14:07:00 +03:00
import subprocess as sp
2022-09-03 12:08:45 +03:00
import sys
2022-10-15 13:11:28 +03:00
import tempfile
2022-09-03 12:08:45 +03:00
import time
import traceback
2022-10-14 20:04:47 +03:00
from functools import partial , reduce
2023-01-18 23:04:24 +03:00
import warnings
2022-09-03 12:08:45 +03:00
2022-10-22 14:07:00 +03:00
import gradio as gr
import gradio . routes
import gradio . utils
2022-09-06 19:33:51 +03:00
import numpy as np
2022-09-28 17:05:23 +03:00
from PIL import Image , PngImagePlugin
2022-11-28 09:00:10 +03:00
from modules . call_queue import wrap_gradio_gpu_call , wrap_queued_call , wrap_gradio_call
2022-09-03 12:08:45 +03:00
2023-01-23 09:24:43 +03:00
from modules import sd_hijack , sd_models , localization , script_callbacks , ui_extensions , deepbooru , sd_vae , extra_networks , postprocessing , ui_components , ui_common , ui_postprocessing
2023-01-07 09:56:37 +03:00
from modules . ui_components import FormRow , FormGroup , ToolButton , FormHTML
2022-09-03 12:08:45 +03:00
from modules . paths import script_path
2022-10-16 22:06:21 +03:00
2022-10-16 20:08:23 +03:00
from modules . shared import opts , cmd_opts , restricted_opts
2022-10-14 11:56:41 +03:00
2022-10-22 14:07:00 +03:00
import modules . codeformer_model
2022-10-27 08:36:11 +03:00
import modules . generation_parameters_copypaste as parameters_copypaste
2022-10-22 14:07:00 +03:00
import modules . gfpgan_model
import modules . hypernetworks . ui
2022-09-03 17:21:15 +03:00
import modules . scripts
2022-10-22 14:07:00 +03:00
import modules . shared as shared
2022-09-09 23:16:02 +03:00
import modules . styles
2022-10-22 14:07:00 +03:00
import modules . textual_inversion . ui
2022-10-05 23:16:27 +03:00
from modules import prompt_parser
2022-10-04 19:19:50 +03:00
from modules . images import save_image
2022-10-22 14:07:00 +03:00
from modules . sd_hijack import model_hijack
from modules . sd_samplers import samplers , samplers_for_img2img
2023-01-09 23:35:40 +03:00
from modules . textual_inversion import textual_inversion
2022-10-11 15:51:22 +03:00
import modules . hypernetworks . ui
2022-10-27 08:36:11 +03:00
from modules . generation_parameters_copypaste import image_from_url_text
2023-01-23 14:42:49 +03:00
import modules . extras
2022-09-03 12:08:45 +03:00
2023-01-18 23:04:24 +03:00
# Surface UserWarning messages only when the show_warnings option is enabled; otherwise ignore them.
warnings.filterwarnings("default" if opts.show_warnings else "ignore", category=UserWarning)

# this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the browser will not show any UI
mimetypes.init()
mimetypes.add_type('application/javascript', '.js')

# The patches below are applied only when neither cmd_opts.share nor cmd_opts.listen is set.
if not cmd_opts.share and not cmd_opts.listen:
    # fix gradio phoning home
    gradio.utils.version_check = lambda: None
    gradio.utils.get_local_ip_address = lambda: '127.0.0.1'

# When cmd_opts.ngrok is provided, call modules.ngrok.connect at import time.
if cmd_opts.ngrok is not None:
    import modules.ngrok as ngrok
    print('ngrok authtoken detected, trying to connect...')
    ngrok.connect(
        cmd_opts.ngrok,
        # fall back to port 7860 when no explicit port was given
        cmd_opts.port if cmd_opts.port is not None else 7860,
        cmd_opts.ngrok_region
    )
2022-10-11 12:40:27 +03:00
2022-09-03 12:08:45 +03:00
def gr_show(visible=True):
    """Build a raw update payload that sets a component's ``visible`` flag.

    Returns a plain dict carrying the ``visible`` value plus the
    ``"__type__": "update"`` marker.
    """
    payload = dict(visible=visible)
    payload["__type__"] = "update"
    return payload
sample_img2img = " assets/stable-samples/img2img/sketch-mountains-input.jpg "
sample_img2img = sample_img2img if os . path . exists ( sample_img2img ) else None
css_hide_progressbar = """
. wrap . m - 12 svg { display : none ! important ; }
2022-09-27 10:44:00 +03:00
. wrap . m - 12 : : before { content : " Loading... " }
2022-11-16 07:08:03 +03:00
. wrap . z - 20 svg { display : none ! important ; }
. wrap . z - 20 : : before { content : " Loading... " }
2023-01-21 23:06:18 +03:00
. wrap . cover - bg . z - 20 : : before { content : " " }
2022-09-03 12:08:45 +03:00
. progress - bar { display : none ! important ; }
. meta - text { display : none ! important ; }
2022-11-16 07:08:03 +03:00
. meta - text - center { display : none ! important ; }
2022-09-03 12:08:45 +03:00
"""
2022-09-16 22:20:56 +03:00
# Using constants for these since the variation selector isn't visible.
# Important that they exactly match script.js for tooltip to work.
# Each constant is the text shown on a small tool button elsewhere in this file.
random_symbol = '\U0001f3b2\ufe0f'  # 🎲️
reuse_symbol = '\u267b\ufe0f'  # ♻️
paste_symbol = '\u2199\ufe0f'  # ↙
refresh_symbol = '\U0001f504'  # 🔄
save_style_symbol = '\U0001f4be'  # 💾
apply_style_symbol = '\U0001f4cb'  # 📋
clear_prompt_symbol = '\U0001F5D1'  # 🗑️
extra_networks_symbol = '\U0001F3B4'  # 🎴
2022-10-13 19:22:41 +03:00
2022-09-16 22:20:56 +03:00
2022-09-03 12:08:45 +03:00
def plaintext_to_html(text):
    """Forward *text* to ``ui_common.plaintext_to_html`` and return its result."""
    converted = ui_common.plaintext_to_html(text)
    return converted
2023-01-22 15:38:39 +03:00
2022-09-03 12:08:45 +03:00
def send_gradio_gallery_to_image(x):
    """Convert the first entry of gallery data *x* with ``image_from_url_text``.

    Returns None when *x* is empty.
    """
    if len(x) > 0:
        return image_from_url_text(x[0])

    return None
2022-09-04 13:52:01 +03:00
def visit(x, func, path=""):
    """Walk a component tree and call *func* on every labelled leaf.

    Objects exposing a ``children`` attribute are treated as containers and
    descended into with the same *path*. Any other object is a leaf: if its
    ``label`` is not None, ``func(path + "/" + label, leaf)`` is invoked.
    """
    if not hasattr(x, 'children'):
        if x.label is not None:
            func("/".join((path, str(x.label))), x)
        return

    for child in x.children:
        visit(child, func, path)
2022-09-03 12:08:45 +03:00
2022-09-11 17:35:12 +03:00
def add_style(name: str, prompt: str, negative_prompt: str):
    """Store a prompt style under *name* and write all styles back to disk.

    Returns a list of component updates: four plain "visible" updates when
    *name* is None, otherwise two dropdown updates listing every style name.
    """
    if name is None:
        # NOTE(review): this path yields four updates while the success path
        # yields two — confirm against the outputs bound at the call site.
        return [gr_show() for _ in range(4)]

    new_style = modules.styles.PromptStyle(name, prompt, negative_prompt)
    shared.prompt_styles.styles[new_style.name] = new_style

    # Save all loaded prompt styles: this allows us to update the storage format in the future more easily, because we
    # reserialize all styles every time we save them
    shared.prompt_styles.save_styles(shared.styles_filename)

    def dropdown_update():
        return gr.Dropdown.update(visible=True, choices=list(shared.prompt_styles.styles))

    return [dropdown_update(), dropdown_update()]
2022-09-14 17:56:21 +03:00
2023-01-07 09:56:37 +03:00
def calc_resolution_hires(enable, width, height, hr_scale, hr_resize_x, hr_resize_y):
    """Return an HTML snippet describing the hires-fix source and target size.

    Returns an empty string when *enable* is falsy. Otherwise a txt2img
    processing object is built with the given settings and initialised, and the
    resolutions it reports are formatted into the snippet.
    """
    from modules import processing, devices

    if not enable:
        return ""

    settings = dict(
        width=width,
        height=height,
        enable_hr=True,
        hr_scale=hr_scale,
        hr_resize_x=hr_resize_x,
        hr_resize_y=hr_resize_y,
    )
    p = processing.StableDiffusionProcessingTxt2Img(**settings)

    with devices.autocast():
        p.init([""], [0], [0])

    source = f"<span class='resolution'>{p.width}x{p.height}</span>"
    target = f"<span class='resolution'>{p.hr_resize_x or p.hr_upscale_to_x}x{p.hr_resize_y or p.hr_upscale_to_y}</span>"
    return f"resize: from {source} to {target}"
2023-01-07 09:56:37 +03:00
2022-09-14 17:56:21 +03:00
2023-01-14 14:56:39 +03:00
def apply_styles(prompt, prompt_neg, styles):
    """Apply the selected *styles* to both prompts.

    Returns three updates: the styled prompt, the styled negative prompt, and
    an update that resets the dropdown value to an empty list.
    """
    styled_prompt = shared.prompt_styles.apply_styles_to_prompt(prompt, styles)
    styled_negative = shared.prompt_styles.apply_negative_styles_to_prompt(prompt_neg, styles)

    return [
        gr.Textbox.update(value=styled_prompt),
        gr.Textbox.update(value=styled_negative),
        gr.Dropdown.update(value=[]),
    ]
2022-09-09 23:16:02 +03:00
2023-01-18 20:16:52 +03:00
def process_interrogate(interrogation_function, mode, ii_input_dir, ii_output_dir, *ii_singles):
    """Run *interrogation_function* on one image or on a directory of images.

    Args:
        interrogation_function: callable taking an image and returning text.
        mode: integer selector. 0, 1, 3 and 4 pick ``ii_singles[mode]`` directly;
            2 picks the ``"image"`` entry of ``ii_singles[2]``; 5 processes
            every file listed for *ii_input_dir*.
        ii_input_dir: directory scanned in mode 5.
        ii_output_dir: directory receiving the ``.txt`` results in mode 5; an
            empty string means "write next to the inputs".
        *ii_singles: the per-mode single-image inputs.

    Returns:
        ``[text, None]`` for the single-image modes, ``[gr.update(), None]``
        after a batch run, and None for any other mode.

    Raises:
        AssertionError: in mode 5 when ``hide_ui_dir_config`` is set.
    """
    if mode in {0, 1, 3, 4}:
        return [interrogation_function(ii_singles[mode]), None]
    elif mode == 2:
        return [interrogation_function(ii_singles[mode]["image"]), None]
    elif mode == 5:
        assert not shared.cmd_opts.hide_ui_dir_config, "Launched with --hide-ui-dir-config, batch img2img disabled"
        images = shared.listfiles(ii_input_dir)
        print(f"Will process {len(images)} images.")
        if ii_output_dir != "":
            os.makedirs(ii_output_dir, exist_ok=True)
        else:
            ii_output_dir = ii_input_dir

        for image in images:
            filename = os.path.basename(image)
            left, _ = os.path.splitext(filename)
            output_path = os.path.join(ii_output_dir, left + ".txt")

            # Close both the image and the output file deterministically;
            # previously one file handle per image was left for the GC to reclaim.
            with Image.open(image) as img:
                result = interrogation_function(img)

            # Append mode: repeated runs add to an existing .txt file.
            with open(output_path, 'a') as output_file:
                print(result, file=output_file)

        return [gr.update(), None]
2023-01-18 20:16:52 +03:00
2022-09-11 18:48:36 +03:00
def interrogate(image):
    """Interrogate *image* (converted to RGB) with ``shared.interrogator``.

    Returns the resulting prompt, or a no-op update when it is None.
    """
    rgb_image = image.convert("RGB")
    prompt = shared.interrogator.interrogate(rgb_image)

    if prompt is None:
        return gr.update()

    return prompt
2022-09-11 18:48:36 +03:00
2022-09-14 17:56:21 +03:00
2022-10-05 21:50:10 +03:00
def interrogate_deepbooru(image):
    """Tag *image* with ``deepbooru.model``.

    Returns the resulting prompt, or a no-op update when it is None.
    """
    prompt = deepbooru.model.tag(image)

    if prompt is None:
        return gr.update()

    return prompt
2022-10-05 21:50:10 +03:00
2023-01-01 16:51:12 +03:00
def create_seed_inputs(target_interface):
    """Create the seed controls for one tab and wire their basic events.

    *target_interface* is used as the prefix of every ``elem_id`` created here.

    Returns the tuple ``(seed, reuse_seed, subseed, reuse_subseed,
    subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox)``;
    note that the height slider precedes the width slider in the result.
    """
    # NOTE(review): nesting of the layout blocks should be confirmed against the rendered UI.
    with FormRow(elem_id=target_interface + '_seed_row'):
        # The seed field is a Textbox or a Number depending on cmd_opts.use_textbox_seed.
        seed = (gr.Textbox if cmd_opts.use_textbox_seed else gr.Number)(label='Seed', value=-1, elem_id=target_interface + '_seed')
        seed.style(container=False)
        random_seed = gr.Button(random_symbol, elem_id=target_interface + '_random_seed')
        reuse_seed = gr.Button(reuse_symbol, elem_id=target_interface + '_reuse_seed')

        with gr.Group(elem_id=target_interface + '_subseed_show_box'):
            seed_checkbox = gr.Checkbox(label='Extra', elem_id=target_interface + '_subseed_show', value=False)

    # Components to show/hide based on the 'Extra' checkbox
    seed_extras = []

    with FormRow(visible=False, elem_id=target_interface + '_subseed_row') as seed_extra_row_1:
        seed_extras.append(seed_extra_row_1)
        subseed = gr.Number(label='Variation seed', value=-1, elem_id=target_interface + '_subseed')
        subseed.style(container=False)
        random_subseed = gr.Button(random_symbol, elem_id=target_interface + '_random_subseed')
        reuse_subseed = gr.Button(reuse_symbol, elem_id=target_interface + '_reuse_subseed')
        subseed_strength = gr.Slider(label='Variation strength', value=0.0, minimum=0, maximum=1, step=0.01, elem_id=target_interface + '_subseed_strength')

    with FormRow(visible=False) as seed_extra_row_2:
        seed_extras.append(seed_extra_row_2)
        seed_resize_from_w = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize seed from width", value=0, elem_id=target_interface + '_seed_resize_from_w')
        seed_resize_from_h = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize seed from height", value=0, elem_id=target_interface + '_seed_resize_from_h')

    # The "random" buttons write -1 into their seed field.
    random_seed.click(fn=lambda: -1, show_progress=False, inputs=[], outputs=[seed])
    random_subseed.click(fn=lambda: -1, show_progress=False, inputs=[], outputs=[subseed])

    def change_visibility(show):
        # One visibility update per extra row, keyed by the row component itself.
        return {comp: gr_show(show) for comp in seed_extras}

    seed_checkbox.change(change_visibility, show_progress=False, inputs=[seed_checkbox], outputs=seed_extras)

    return seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox
2022-09-16 22:20:56 +03:00
2022-10-21 22:24:14 +03:00
2022-11-01 22:33:55 +03:00
def connect_clear_prompt(button):
    """Bind *button*'s click event to the client-side ``clear_prompt`` handler.

    No Python callback is attached and no components are passed as inputs or
    outputs.
    """
    click_options = dict(
        _js="clear_prompt",
        fn=None,
        inputs=[],
        outputs=[],
    )
    button.click(**click_options)
2022-10-20 09:08:24 +03:00
2022-09-19 09:02:10 +03:00
def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info: gr.Textbox, dummy_component, is_subseed):
    """Wire a 'reuse (sub)seed' button so it copies the last used (sub)seed.

    On click, the generation info JSON is parsed and the seed belonging to the
    currently selected gallery image is written into *seed*. When *is_subseed*
    is set but the recorded subseed strength is 0 (no variation seed was used),
    the normal seed is copied instead.
    """
    def copy_seed(gen_info_string: str, index):
        picked = -1

        try:
            gen_info = json.loads(gen_info_string)
            # Gallery indices are offset by 'index_of_first_image' when present.
            index -= gen_info.get('index_of_first_image', 0)

            use_subseed = is_subseed and gen_info.get('subseed_strength', 0) > 0
            candidates = gen_info.get('all_subseeds' if use_subseed else 'all_seeds', [-1])

            # Out-of-range selections fall back to the first entry.
            position = index if 0 <= index < len(candidates) else 0
            picked = candidates[position]
        except json.decoder.JSONDecodeError:
            # An empty info string is not reported.
            if gen_info_string != '':
                print("Error parsing JSON generation info:", file=sys.stderr)
                print(gen_info_string, file=sys.stderr)

        return [picked, gr_show(False)]

    click_options = dict(
        fn=copy_seed,
        _js="(x, y) => [x, selected_gallery_index()]",
        show_progress=False,
        inputs=[generation_info, dummy_component],
        outputs=[seed, dummy_component],
    )
    reuse_seed.click(**click_options)
2022-10-04 14:35:12 +03:00
2022-09-29 22:47:06 +03:00
def update_token_counter(text, steps):
    """Return an HTML span showing the token count of *text* against the limit.

    The prompt is stripped of extra-network syntax and expanded into its
    scheduled variants; the variant with the highest token count is reported.
    """
    try:
        text, _ = extra_networks.parse_prompt(text)

        _, flat_list, _ = prompt_parser.get_multicond_prompt_list([text])
        schedules = prompt_parser.get_learned_conditioning_prompt_schedules(flat_list, steps)
    except Exception:
        # a parsing error can happen here during typing, and we don't want to bother the user with
        # messages related to it in console
        schedules = [[[steps, text]]]

    def concatenate(left, right):
        return left + right

    flattened = reduce(concatenate, schedules)
    variants = [variant for _step, variant in flattened]
    lengths = [model_hijack.get_prompt_lengths(variant) for variant in variants]
    token_count, max_length = max(lengths, key=lambda pair: pair[0])
    return f"<span class='gr-box gr-text-input'>{token_count}/{max_length}</span>"
2022-09-19 16:42:56 +03:00
2022-10-04 14:35:12 +03:00
2022-09-14 17:56:21 +03:00
def create_toprow(is_img2img):
    """Build the top row (prompts, action buttons, tools, styles) of a tab.

    *is_img2img* selects the ``elem_id`` prefix ("img2img" or "txt2img") and
    controls whether the two interrogate buttons are created; for txt2img both
    are returned as None.

    Returns the tuple ``(prompt, prompt_styles, negative_prompt, submit,
    button_interrogate, button_deepbooru, prompt_style_apply, save_style, paste,
    extra_networks_button, token_counter, token_button, negative_token_counter,
    negative_token_button)``.
    """
    id_part = "img2img" if is_img2img else "txt2img"

    # NOTE(review): nesting of the layout blocks should be confirmed against the rendered UI.
    with gr.Row(elem_id=f"{id_part}_toprow", variant="compact"):
        with gr.Column(elem_id=f"{id_part}_prompt_container", scale=6):
            with gr.Row():
                with gr.Column(scale=80):
                    with gr.Row():
                        prompt = gr.Textbox(label="Prompt", elem_id=f"{id_part}_prompt", show_label=False, lines=3, placeholder="Prompt (press Ctrl+Enter or Alt+Enter to generate)")

            with gr.Row():
                with gr.Column(scale=80):
                    with gr.Row():
                        negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{id_part}_neg_prompt", show_label=False, lines=2, placeholder="Negative prompt (press Ctrl+Enter or Alt+Enter to generate)")

        # Interrogate buttons exist only on the img2img tab.
        button_interrogate = None
        button_deepbooru = None
        if is_img2img:
            with gr.Column(scale=1, elem_id="interrogate_col"):
                button_interrogate = gr.Button('Interrogate\nCLIP', elem_id="interrogate")
                button_deepbooru = gr.Button('Interrogate\nDeepBooru', elem_id="deepbooru")

        with gr.Column(scale=1, elem_id=f"{id_part}_actions_column"):
            with gr.Row(elem_id=f"{id_part}_generate_box"):
                interrupt = gr.Button('Interrupt', elem_id=f"{id_part}_interrupt")
                skip = gr.Button('Skip', elem_id=f"{id_part}_skip")
                submit = gr.Button('Generate', elem_id=f"{id_part}_generate", variant='primary')

                # Skip and Interrupt call straight into shared.state; no components are read or updated.
                skip.click(
                    fn=lambda: shared.state.skip(),
                    inputs=[],
                    outputs=[],
                )

                interrupt.click(
                    fn=lambda: shared.state.interrupt(),
                    inputs=[],
                    outputs=[],
                )

            with gr.Row(elem_id=f"{id_part}_tools"):
                paste = ToolButton(value=paste_symbol, elem_id="paste")
                clear_prompt_button = ToolButton(value=clear_prompt_symbol, elem_id=f"{id_part}_clear_prompt")
                extra_networks_button = ToolButton(value=extra_networks_symbol, elem_id=f"{id_part}_extra_networks")
                prompt_style_apply = ToolButton(value=apply_style_symbol, elem_id=f"{id_part}_style_apply")
                save_style = ToolButton(value=save_style_symbol, elem_id=f"{id_part}_style_create")

                # The token buttons are created invisible; each is paired with an HTML counter.
                token_counter = gr.HTML(value="<span></span>", elem_id=f"{id_part}_token_counter")
                token_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button")
                negative_token_counter = gr.HTML(value="<span></span>", elem_id=f"{id_part}_negative_token_counter")
                negative_token_button = gr.Button(visible=False, elem_id=f"{id_part}_negative_token_button")

                # The Python side is an identity function: whatever the "confirm_clear_prompt"
                # JS handler returns is written back into the two prompt boxes.
                clear_prompt_button.click(
                    fn=lambda *x: x,
                    _js="confirm_clear_prompt",
                    inputs=[prompt, negative_prompt],
                    outputs=[prompt, negative_prompt],
                )

            with gr.Row(elem_id=f"{id_part}_styles_row"):
                prompt_styles = gr.Dropdown(label="Styles", elem_id=f"{id_part}_styles", choices=[k for k, v in shared.prompt_styles.styles.items()], value=[], multiselect=True)
                # Refresh reloads the styles and then re-reads their names for the dropdown choices.
                create_refresh_button(prompt_styles, shared.prompt_styles.reload, lambda: {"choices": [k for k, v in shared.prompt_styles.styles.items()]}, f"refresh_{id_part}_styles")

    return prompt, prompt_styles, negative_prompt, submit, button_interrogate, button_deepbooru, prompt_style_apply, save_style, paste, extra_networks_button, token_counter, token_button, negative_token_counter, negative_token_button
2022-09-14 17:56:21 +03:00
2023-01-10 12:29:45 +03:00
def setup_progressbar(*args, **kwargs):
    """Accept any arguments and do nothing."""
    return None
2022-09-14 17:56:21 +03:00
2022-10-14 19:30:28 +03:00
def apply_setting(key, value):
    """Store *value* under option *key*, persist the options and return the new value.

    A no-op update is returned instead (and nothing is changed) when *value* is
    None, when settings are frozen, or — for ``sd_model_checkpoint`` — when
    automatic weight swapping is disabled or no matching checkpoint is found.
    For options whose component args mark them as not visible, None is returned.
    """
    if value is None or shared.cmd_opts.freeze_settings:
        return gr.update()

    if key == "sd_model_checkpoint":
        # dont allow model to be swapped when model hash exists in prompt
        if opts.disable_weights_auto_swap:
            return gr.update()

        ckpt_info = sd_models.get_closet_checkpoint_match(value)
        if ckpt_info is None:
            return gr.update()

        value = ckpt_info.title

    option_info = opts.data_labels[key]

    comp_args = option_info.component_args
    if comp_args and isinstance(comp_args, dict) and comp_args.get('visible') is False:
        # NOTE(review): bare return (None) here, unlike the gr.update() used by the other early exits.
        return

    valtype = type(opts.data_labels[key].default)
    oldval = opts.data.get(key, None)

    # Coerce to the type of the option's default, unless that default is None.
    if valtype != type(None):
        opts.data[key] = valtype(value)
    else:
        opts.data[key] = value

    # NOTE(review): the comparison uses the unconverted value, not the stored one.
    if oldval != value and opts.data_labels[key].onchange is not None:
        opts.data_labels[key].onchange()

    opts.save(shared.config_filename)
    return getattr(opts, key)
2022-10-14 19:30:28 +03:00
2022-10-21 16:10:51 +03:00
def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id):
    """Create a tool button that refreshes *refresh_component* when clicked.

    Args:
        refresh_component: component whose attributes are updated.
        refresh_method: zero-argument callable invoked first on every click.
        refreshed_args: dict of attribute values, or a zero-argument callable
            returning such a dict; evaluated after *refresh_method* has run.
            None or an empty result means "nothing to update".
        elem_id: ``elem_id`` given to the created button.

    Returns:
        The created button.
    """
    def refresh():
        refresh_method()
        args = refreshed_args() if callable(refreshed_args) else refreshed_args
        # Normalise once so a None/empty result cannot crash the attribute loop
        # (the update call below already tolerated it).
        args = args or {}

        # Mirror the new values onto the component object as well as sending them in the update.
        for k, v in args.items():
            setattr(refresh_component, k, v)

        return gr.update(**args)

    refresh_button = ToolButton(value=refresh_symbol, elem_id=elem_id)
    refresh_button.click(
        fn=refresh,
        inputs=[],
        outputs=[refresh_component]
    )
    return refresh_button
2022-10-16 07:42:52 +03:00
2022-10-29 08:28:48 +03:00
def create_output_panel(tabname, outdir):
    """Forward to ``ui_common.create_output_panel`` and return its result."""
    panel = ui_common.create_output_panel(tabname, outdir)
    return panel
2022-10-10 04:26:52 +03:00
2022-10-08 08:09:29 +03:00
2023-01-01 01:19:10 +03:00
def create_sampler_and_steps_selection(choices, tabname):
    """Create the sampler selector and the steps slider for *tabname*.

    With ``opts.samplers_in_dropdown`` the sampler is a dropdown placed before
    the slider inside a FormRow; otherwise it is a radio group placed after the
    slider inside a FormGroup. Returns ``(steps, sampler_index)``.
    """
    def make_steps():
        return gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"{tabname}_steps", label="Sampling steps", value=20)

    def make_sampler(component_cls):
        return component_cls(label='Sampling method', elem_id=f"{tabname}_sampling", choices=[x.name for x in choices], value=choices[0].name, type="index")

    if not opts.samplers_in_dropdown:
        with FormGroup(elem_id=f"sampler_selection_{tabname}"):
            steps = make_steps()
            sampler_index = make_sampler(gr.Radio)
    else:
        with FormRow(elem_id=f"sampler_selection_{tabname}"):
            sampler_index = make_sampler(gr.Dropdown)
            steps = make_steps()

    return steps, sampler_index
2022-10-29 08:28:48 +03:00
2022-10-16 07:42:52 +03:00
2023-01-03 10:39:21 +03:00
def ordered_ui_categories():
    """Yield the UI category names in the user's preferred order.

    Names listed in the comma-separated ``ui_reorder`` option get an odd rank
    (``2 * position + 1``); unlisted names keep an even rank derived from their
    default position (``2 * index``), and the categories are yielded by
    ascending rank.
    """
    user_rank = {}
    for position, name in enumerate(shared.opts.ui_reorder.split(",")):
        user_rank[name.strip()] = position * 2 + 1

    def rank(indexed_category):
        default_position, name = indexed_category
        return user_rank.get(name, default_position * 2 + 0)

    ranked = sorted(enumerate(shared.ui_reorder_categories), key=rank)
    for _, category in ranked:
        yield category
2023-01-19 18:58:08 +03:00
def get_value_for_setting(key):
    """Build a component update carrying the current value of option *key*.

    The option's component args (called first if they are a callable) are
    passed along with the value, except for ``precision``.
    """
    current_value = getattr(opts, key)
    option_info = opts.data_labels[key]

    if callable(option_info.component_args):
        component_args = option_info.component_args()
    else:
        component_args = option_info.component_args or {}

    excluded = {'precision'}
    forwarded = {name: arg for name, arg in component_args.items() if name not in excluded}

    return gr.update(value=current_value, **forwarded)
2022-11-28 09:00:10 +03:00
def create_ui ( ) :
2022-10-21 16:10:51 +03:00
import modules . img2img
import modules . txt2img
2022-10-16 07:42:52 +03:00
2022-11-02 07:26:31 +03:00
reload_javascript ( )
2022-10-31 17:36:45 +03:00
parameters_copypaste . reset ( )
2022-10-16 07:42:52 +03:00
2022-11-19 19:10:17 +03:00
modules . scripts . scripts_current = modules . scripts . scripts_txt2img
modules . scripts . scripts_txt2img . initialize_scripts ( is_img2img = False )
2022-09-03 12:08:45 +03:00
with gr . Blocks ( analytics_enabled = False ) as txt2img_interface :
2023-01-21 08:36:07 +03:00
txt2img_prompt , txt2img_prompt_styles , txt2img_negative_prompt , submit , _ , _ , txt2img_prompt_style_apply , txt2img_save_style , txt2img_paste , extra_networks_button , token_counter , token_button , negative_token_counter , negative_token_button = create_toprow ( is_img2img = False )
2022-10-20 05:23:57 +03:00
2022-09-19 09:02:10 +03:00
dummy_component = gr . Label ( visible = False )
2023-01-18 23:04:24 +03:00
txt_prompt_img = gr . File ( label = " " , elem_id = " txt2img_prompt_image " , file_count = " single " , type = " binary " , visible = False )
2022-09-03 12:08:45 +03:00
2023-01-21 08:36:07 +03:00
with FormRow ( variant = ' compact ' , elem_id = " txt2img_extra_networks " , visible = False ) as extra_networks :
from modules import ui_extra_networks
extra_networks_ui = ui_extra_networks . create_ui ( extra_networks , extra_networks_button , ' txt2img ' )
2022-09-03 12:08:45 +03:00
with gr . Row ( ) . style ( equal_height = False ) :
2023-01-14 13:38:10 +03:00
with gr . Column ( variant = ' compact ' , elem_id = " txt2img_settings " ) :
2023-01-03 10:39:21 +03:00
for category in ordered_ui_categories ( ) :
if category == " sampler " :
steps , sampler_index = create_sampler_and_steps_selection ( samplers , " txt2img " )
2022-09-19 16:42:56 +03:00
2023-01-03 10:39:21 +03:00
elif category == " dimensions " :
with FormRow ( ) :
with gr . Column ( elem_id = " txt2img_column_size " , scale = 4 ) :
width = gr . Slider ( minimum = 64 , maximum = 2048 , step = 8 , label = " Width " , value = 512 , elem_id = " txt2img_width " )
height = gr . Slider ( minimum = 64 , maximum = 2048 , step = 8 , label = " Height " , value = 512 , elem_id = " txt2img_height " )
if opts . dimensions_and_batch_together :
with gr . Column ( elem_id = " txt2img_column_batch " ) :
batch_count = gr . Slider ( minimum = 1 , step = 1 , label = ' Batch count ' , value = 1 , elem_id = " txt2img_batch_count " )
batch_size = gr . Slider ( minimum = 1 , maximum = 8 , step = 1 , label = ' Batch size ' , value = 1 , elem_id = " txt2img_batch_size " )
elif category == " cfg " :
cfg_scale = gr . Slider ( minimum = 1.0 , maximum = 30.0 , step = 0.5 , label = ' CFG Scale ' , value = 7.0 , elem_id = " txt2img_cfg_scale " )
elif category == " seed " :
seed , reuse_seed , subseed , reuse_subseed , subseed_strength , seed_resize_from_h , seed_resize_from_w , seed_checkbox = create_seed_inputs ( ' txt2img ' )
elif category == " checkboxes " :
2023-01-19 00:44:51 +03:00
with FormRow ( elem_id = " txt2img_checkboxes " , variant = " compact " ) :
2023-01-03 10:39:21 +03:00
restore_faces = gr . Checkbox ( label = ' Restore faces ' , value = False , visible = len ( shared . face_restorers ) > 1 , elem_id = " txt2img_restore_faces " )
tiling = gr . Checkbox ( label = ' Tiling ' , value = False , elem_id = " txt2img_tiling " )
enable_hr = gr . Checkbox ( label = ' Hires. fix ' , value = False , elem_id = " txt2img_enable_hr " )
2023-01-07 09:56:37 +03:00
hr_final_resolution = FormHTML ( value = " " , elem_id = " txtimg_hr_finalres " , label = " Upscaled resolution " , interactive = False )
2023-01-03 10:39:21 +03:00
elif category == " hires_fix " :
2023-01-04 22:04:40 +03:00
with FormGroup ( visible = False , elem_id = " txt2img_hires_fix " ) as hr_options :
2023-01-19 00:44:51 +03:00
with FormRow ( elem_id = " txt2img_hires_fix_row1 " , variant = " compact " ) :
2023-01-04 22:04:40 +03:00
hr_upscaler = gr . Dropdown ( label = " Upscaler " , elem_id = " txt2img_hr_upscaler " , choices = [ * shared . latent_upscale_modes , * [ x . name for x in shared . sd_upscalers ] ] , value = shared . latent_upscale_default_mode )
hr_second_pass_steps = gr . Slider ( minimum = 0 , maximum = 150 , step = 1 , label = ' Hires steps ' , value = 0 , elem_id = " txt2img_hires_steps " )
denoising_strength = gr . Slider ( minimum = 0.0 , maximum = 1.0 , step = 0.01 , label = ' Denoising strength ' , value = 0.7 , elem_id = " txt2img_denoising_strength " )
2023-01-19 00:44:51 +03:00
with FormRow ( elem_id = " txt2img_hires_fix_row2 " , variant = " compact " ) :
2023-01-04 22:04:40 +03:00
hr_scale = gr . Slider ( minimum = 1.0 , maximum = 4.0 , step = 0.05 , label = " Upscale by " , value = 2.0 , elem_id = " txt2img_hr_scale " )
hr_resize_x = gr . Slider ( minimum = 0 , maximum = 2048 , step = 8 , label = " Resize width to " , value = 0 , elem_id = " txt2img_hr_resize_x " )
hr_resize_y = gr . Slider ( minimum = 0 , maximum = 2048 , step = 8 , label = " Resize height to " , value = 0 , elem_id = " txt2img_hr_resize_y " )
2023-01-03 10:39:21 +03:00
elif category == " batch " :
if not opts . dimensions_and_batch_together :
with FormRow ( elem_id = " txt2img_column_batch " ) :
batch_count = gr . Slider ( minimum = 1 , step = 1 , label = ' Batch count ' , value = 1 , elem_id = " txt2img_batch_count " )
batch_size = gr . Slider ( minimum = 1 , maximum = 8 , step = 1 , label = ' Batch size ' , value = 1 , elem_id = " txt2img_batch_size " )
elif category == " scripts " :
with FormGroup ( elem_id = " txt2img_script_container " ) :
custom_inputs = modules . scripts . scripts_txt2img . setup_ui ( )
2022-09-03 12:08:45 +03:00
2023-01-07 09:56:37 +03:00
hr_resolution_preview_inputs = [ enable_hr , width , height , hr_scale , hr_resize_x , hr_resize_y ]
for input in hr_resolution_preview_inputs :
2023-01-09 14:57:47 +03:00
input . change (
fn = calc_resolution_hires ,
inputs = hr_resolution_preview_inputs ,
outputs = [ hr_final_resolution ] ,
show_progress = False ,
)
input . change (
None ,
_js = " onCalcResolutionHires " ,
inputs = hr_resolution_preview_inputs ,
outputs = [ ] ,
show_progress = False ,
)
2023-01-07 08:53:53 +03:00
2022-12-31 23:40:55 +03:00
txt2img_gallery , generation_info , html_info , html_log = create_output_panel ( " txt2img " , opts . outdir_txt2img_samples )
2022-10-29 10:56:19 +03:00
parameters_copypaste . bind_buttons ( { " txt2img " : txt2img_paste } , None , txt2img_prompt )
2022-09-03 12:08:45 +03:00
2022-09-19 09:02:10 +03:00
connect_reuse_seed ( seed , reuse_seed , generation_info , dummy_component , is_subseed = False )
connect_reuse_seed ( subseed , reuse_subseed , generation_info , dummy_component , is_subseed = True )
2022-09-16 22:20:56 +03:00
2022-09-03 12:08:45 +03:00
txt2img_args = dict (
2022-12-31 23:40:55 +03:00
fn = wrap_gradio_gpu_call ( modules . txt2img . txt2img , extra_outputs = [ None , ' ' , ' ' ] ) ,
2022-09-06 02:09:01 +03:00
_js = " submit " ,
2022-09-03 12:08:45 +03:00
inputs = [
2023-01-15 18:50:56 +03:00
dummy_component ,
2022-09-09 23:16:02 +03:00
txt2img_prompt ,
2022-09-11 17:35:12 +03:00
txt2img_negative_prompt ,
2023-01-14 14:56:39 +03:00
txt2img_prompt_styles ,
2022-09-03 12:08:45 +03:00
steps ,
sampler_index ,
2022-09-07 12:32:28 +03:00
restore_faces ,
2022-09-05 03:25:37 +03:00
tiling ,
2022-09-03 12:08:45 +03:00
batch_count ,
batch_size ,
cfg_scale ,
seed ,
2022-09-21 13:34:10 +03:00
subseed , subseed_strength , seed_resize_from_h , seed_resize_from_w , seed_checkbox ,
2022-09-03 12:08:45 +03:00
height ,
width ,
2022-09-19 16:42:56 +03:00
enable_hr ,
denoising_strength ,
2023-01-02 19:42:10 +03:00
hr_scale ,
hr_upscaler ,
2023-01-04 22:04:40 +03:00
hr_second_pass_steps ,
hr_resize_x ,
hr_resize_y ,
2022-09-03 17:21:15 +03:00
] + custom_inputs ,
2022-10-15 17:20:17 +03:00
2022-09-03 12:08:45 +03:00
outputs = [
txt2img_gallery ,
generation_info ,
2022-12-31 23:40:55 +03:00
html_info ,
html_log ,
2022-09-18 11:14:42 +03:00
] ,
show_progress = False ,
2022-09-03 12:08:45 +03:00
)
2022-09-09 23:16:02 +03:00
txt2img_prompt . submit ( * * txt2img_args )
2022-09-03 12:08:45 +03:00
submit . click ( * * txt2img_args )
2022-10-13 02:17:26 +03:00
txt_prompt_img . change (
fn = modules . images . image_data ,
inputs = [
txt_prompt_img
] ,
outputs = [
txt2img_prompt ,
txt_prompt_img
]
)
2022-09-19 16:42:56 +03:00
enable_hr . change (
fn = lambda x : gr_show ( x ) ,
inputs = [ enable_hr ] ,
outputs = [ hr_options ] ,
2023-01-07 09:56:37 +03:00
show_progress = False ,
2022-09-19 16:42:56 +03:00
)
2022-09-25 09:25:28 +03:00
txt2img_paste_fields = [
( txt2img_prompt , " Prompt " ) ,
( txt2img_negative_prompt , " Negative prompt " ) ,
( steps , " Steps " ) ,
( sampler_index , " Sampler " ) ,
( restore_faces , " Face restoration " ) ,
( cfg_scale , " CFG scale " ) ,
( seed , " Seed " ) ,
( width , " Size-1 " ) ,
( height , " Size-2 " ) ,
( batch_size , " Batch size " ) ,
( subseed , " Variation seed " ) ,
( subseed_strength , " Variation seed strength " ) ,
( seed_resize_from_w , " Seed resize from-1 " ) ,
( seed_resize_from_h , " Seed resize from-2 " ) ,
( denoising_strength , " Denoising strength " ) ,
( enable_hr , lambda d : " Denoising strength " in d ) ,
( hr_options , lambda d : gr . Row . update ( visible = " Denoising strength " in d ) ) ,
2023-01-02 19:42:10 +03:00
( hr_scale , " Hires upscale " ) ,
( hr_upscaler , " Hires upscaler " ) ,
2023-01-04 22:04:40 +03:00
( hr_second_pass_steps , " Hires steps " ) ,
( hr_resize_x , " Hires resize-1 " ) ,
( hr_resize_y , " Hires resize-2 " ) ,
2022-10-22 12:23:45 +03:00
* modules . scripts . scripts_txt2img . infotext_fields
2022-09-25 09:25:28 +03:00
]
2022-10-29 10:56:19 +03:00
parameters_copypaste . add_paste_fields ( " txt2img " , None , txt2img_paste_fields )
2022-10-14 20:31:49 +03:00
txt2img_preview_params = [
txt2img_prompt ,
txt2img_negative_prompt ,
steps ,
sampler_index ,
cfg_scale ,
seed ,
width ,
height ,
]
2022-11-28 09:00:10 +03:00
token_button . click ( fn = wrap_queued_call ( update_token_counter ) , inputs = [ txt2img_prompt , steps ] , outputs = [ token_counter ] )
2023-01-20 10:18:41 +03:00
negative_token_button . click ( fn = wrap_queued_call ( update_token_counter ) , inputs = [ txt2img_negative_prompt , steps ] , outputs = [ negative_token_counter ] )
2022-09-23 22:49:21 +03:00
2023-01-21 08:36:07 +03:00
ui_extra_networks . setup_ui ( extra_networks_ui , txt2img_gallery )
2022-11-19 19:10:17 +03:00
modules . scripts . scripts_current = modules . scripts . scripts_img2img
modules . scripts . scripts_img2img . initialize_scripts ( is_img2img = True )
2022-09-23 22:49:21 +03:00
2022-09-03 12:08:45 +03:00
with gr . Blocks ( analytics_enabled = False ) as img2img_interface :
2023-01-21 08:36:07 +03:00
img2img_prompt , img2img_prompt_styles , img2img_negative_prompt , submit , img2img_interrogate , img2img_deepbooru , img2img_prompt_style_apply , img2img_save_style , img2img_paste , extra_networks_button , token_counter , token_button , negative_token_counter , negative_token_button = create_toprow ( is_img2img = True )
2022-09-03 12:08:45 +03:00
2023-01-18 23:04:24 +03:00
img2img_prompt_img = gr . File ( label = " " , elem_id = " img2img_prompt_image " , file_count = " single " , type = " binary " , visible = False )
2022-09-22 12:11:48 +03:00
2023-01-21 08:36:07 +03:00
with FormRow ( variant = ' compact ' , elem_id = " img2img_extra_networks " , visible = False ) as extra_networks :
from modules import ui_extra_networks
extra_networks_ui_img2img = ui_extra_networks . create_ui ( extra_networks , extra_networks_button , ' img2img ' )
2023-01-03 09:04:29 +03:00
with FormRow ( ) . style ( equal_height = False ) :
2023-01-14 13:38:10 +03:00
with gr . Column ( variant = ' compact ' , elem_id = " img2img_settings " ) :
2023-01-14 22:43:01 +03:00
copy_image_buttons = [ ]
copy_image_destinations = { }
# Build the "Copy image to:" row shown under one img2img input tab.
#
# tab_name: internal name of the tab this row belongs to; compared against the
#           destination names listed in the second zip() argument below.
# elem:     the image component that lives on that tab.
#
# Side effects (both containers are defined just above this function):
#   * copy_image_destinations[tab_name] is set to elem, registering this tab's
#     component as the target other tabs copy into;
#   * for every other destination, (button, name, elem) is appended to
#     copy_image_buttons; the click handlers are attached later by the
#     `for button, name, elem in copy_image_buttons` loop, once every
#     destination has been registered.
def add_copy_image_controls ( tab_name , elem ) :
with gr . Row ( variant = " compact " , elem_id = f " img2img_copy_to_ { tab_name } " ) :
gr . HTML ( " Copy image to: " , elem_id = f " img2img_label_copy_to_ { tab_name } " )
# Pairs of (button caption, internal destination name).
for title , name in zip ( [ ' img2img ' , ' sketch ' , ' inpaint ' , ' inpaint sketch ' ] , [ ' img2img ' , ' sketch ' , ' inpaint ' , ' inpaint_sketch ' ] ) :
if name == tab_name :
# The tab's own entry is rendered as a disabled button and is used
# only to record this tab's component as a copy destination.
gr . Button ( title , interactive = False )
copy_image_destinations [ name ] = elem
continue
button = gr . Button ( title )
# Source component is `elem`; the destination is looked up by `name` later.
copy_image_buttons . append ( ( button , name , elem ) )
2023-01-11 20:33:24 +03:00
with gr . Tabs ( elem_id = " mode_img2img " ) :
with gr . TabItem ( ' img2img ' , id = ' img2img ' , elem_id = " img2img_img2img_tab " ) as tab_img2img :
init_img = gr . Image ( label = " Image for img2img " , elem_id = " img2img_image " , show_label = False , source = " upload " , interactive = True , type = " pil " , tool = " editor " , image_mode = " RGBA " ) . style ( height = 480 )
2023-01-14 22:43:01 +03:00
add_copy_image_controls ( ' img2img ' , init_img )
2022-09-09 19:43:16 +03:00
2023-01-11 20:33:24 +03:00
with gr . TabItem ( ' Sketch ' , id = ' img2img_sketch ' , elem_id = " img2img_img2img_sketch_tab " ) as tab_sketch :
sketch = gr . Image ( label = " Image for img2img " , elem_id = " img2img_sketch " , show_label = False , source = " upload " , interactive = True , type = " pil " , tool = " color-sketch " , image_mode = " RGBA " ) . style ( height = 480 )
2023-01-14 22:43:01 +03:00
add_copy_image_controls ( ' sketch ' , sketch )
2022-09-03 12:08:45 +03:00
2023-01-11 20:33:24 +03:00
with gr . TabItem ( ' Inpaint ' , id = ' inpaint ' , elem_id = " img2img_inpaint_tab " ) as tab_inpaint :
init_img_with_mask = gr . Image ( label = " Image for inpainting with mask " , show_label = False , elem_id = " img2maskimg " , source = " upload " , interactive = True , type = " pil " , tool = " sketch " , image_mode = " RGBA " ) . style ( height = 480 )
2023-01-14 22:43:01 +03:00
add_copy_image_controls ( ' inpaint ' , init_img_with_mask )
2022-09-22 12:11:48 +03:00
2023-01-11 20:33:24 +03:00
with gr . TabItem ( ' Inpaint sketch ' , id = ' inpaint_sketch ' , elem_id = " img2img_inpaint_sketch_tab " ) as tab_inpaint_color :
inpaint_color_sketch = gr . Image ( label = " Color sketch inpainting " , show_label = False , elem_id = " inpaint_sketch " , source = " upload " , interactive = True , type = " pil " , tool = " color-sketch " , image_mode = " RGBA " ) . style ( height = 480 )
inpaint_color_sketch_orig = gr . State ( None )
2023-01-14 22:43:01 +03:00
add_copy_image_controls ( ' inpaint_sketch ' , inpaint_color_sketch )
2022-09-03 12:08:45 +03:00
2023-01-11 20:33:24 +03:00
def update_orig(image, state):
    """Decide which image should be remembered as the un-painted original.

    Used as the ``change`` handler of the inpaint-sketch image: ``image`` is
    the component's current value and ``state`` is the previously stored
    original (or ``None``).

    Returns:
        ``None`` when ``image`` is ``None`` (the stored original is cleared);
        ``state`` when ``image`` looks like an edited version of it, i.e. it
        has the same size and at least one position whose values along the
        last axis all match; otherwise ``image``, which becomes the new
        original.
    """
    if image is None:
        return None

    # Nothing stored yet, or a differently sized picture was loaded: the
    # incoming image is the new original.  Returning early also avoids
    # comparing arrays of incompatible shapes, which raises on recent NumPy
    # releases instead of evaluating to False.
    if state is None or state.size != image.size:
        return image

    current = np.array(image)
    original = np.array(state)

    # Same size but a different array layout (e.g. another number of
    # channels) cannot be an edit of the stored original.
    if current.shape != original.shape:
        return image

    # Last axis is the channel axis for the RGBA PIL images this component
    # is configured to produce.
    has_exact_match = np.any(np.all(current == original, axis=-1))
    return state if has_exact_match else image
2022-09-22 12:11:48 +03:00
2023-01-11 20:33:24 +03:00
inpaint_color_sketch . change ( update_orig , [ inpaint_color_sketch , inpaint_color_sketch_orig ] , inpaint_color_sketch_orig )
2022-09-22 12:11:48 +03:00
2023-01-11 20:33:24 +03:00
with gr . TabItem ( ' Inpaint upload ' , id = ' inpaint_upload ' , elem_id = " img2img_inpaint_upload_tab " ) as tab_inpaint_upload :
init_img_inpaint = gr . Image ( label = " Image for img2img " , show_label = False , source = " upload " , interactive = True , type = " pil " , elem_id = " img_inpaint_base " )
init_mask_inpaint = gr . Image ( label = " Mask " , source = " upload " , interactive = True , type = " pil " , elem_id = " img_inpaint_mask " )
2022-09-22 12:11:48 +03:00
2023-01-11 20:33:24 +03:00
with gr . TabItem ( ' Batch ' , id = ' batch ' , elem_id = " img2img_batch_tab " ) as tab_batch :
2022-09-24 16:29:20 +03:00
hidden = ' <br>Disabled when launched with --hide-ui-dir-config. ' if shared . cmd_opts . hide_ui_dir_config else ' '
2023-01-14 22:43:01 +03:00
gr . HTML ( f " <p style= ' padding-bottom: 1em; ' class= \" text-gray-500 \" >Process images in a directory on the same machine where the server is running.<br>Use an empty output directory to save pictures normally instead of writing to the output directory. { hidden } </p> " )
2023-01-01 16:51:12 +03:00
img2img_batch_input_dir = gr . Textbox ( label = " Input directory " , * * shared . hide_dirs , elem_id = " img2img_batch_input_dir " )
img2img_batch_output_dir = gr . Textbox ( label = " Output directory " , * * shared . hide_dirs , elem_id = " img2img_batch_output_dir " )
2022-09-03 12:08:45 +03:00
2023-01-14 22:43:01 +03:00
# Click handler for the "Copy image to" buttons: return the picture to place
# into the destination component.
# If `img` is a dict that has an image entry, that entry is returned; any
# other value (including None) is returned unchanged.
# NOTE(review): presumably the sketch/inpaint image components deliver their
# value as a dict holding the picture plus drawing data - confirm against the
# Gradio version in use.
def copy_image ( img ) :
if isinstance ( img , dict ) and ' image ' in img :
return img [ ' image ' ]
return img
for button , name , elem in copy_image_buttons :
button . click (
fn = copy_image ,
inputs = [ elem ] ,
outputs = [ copy_image_destinations [ name ] ] ,
)
button . click (
fn = lambda : None ,
_js = " switch_to_ " + name . replace ( " " , " _ " ) ,
inputs = [ ] ,
outputs = [ ] ,
)
2023-01-03 09:04:29 +03:00
with FormRow ( ) :
resize_mode = gr . Radio ( label = " Resize mode " , elem_id = " resize_mode " , choices = [ " Just resize " , " Crop and resize " , " Resize and fill " , " Just resize (latent upscale) " ] , type = " index " , value = " Just resize " )
2022-09-22 12:11:48 +03:00
2023-01-03 10:39:21 +03:00
for category in ordered_ui_categories ( ) :
if category == " sampler " :
steps , sampler_index = create_sampler_and_steps_selection ( samplers_for_img2img , " img2img " )
2022-09-03 21:02:38 +03:00
2023-01-03 10:39:21 +03:00
elif category == " dimensions " :
with FormRow ( ) :
with gr . Column ( elem_id = " img2img_column_size " , scale = 4 ) :
width = gr . Slider ( minimum = 64 , maximum = 2048 , step = 8 , label = " Width " , value = 512 , elem_id = " img2img_width " )
height = gr . Slider ( minimum = 64 , maximum = 2048 , step = 8 , label = " Height " , value = 512 , elem_id = " img2img_height " )
2022-09-25 08:40:29 +03:00
2023-01-03 10:39:21 +03:00
if opts . dimensions_and_batch_together :
with gr . Column ( elem_id = " img2img_column_batch " ) :
batch_count = gr . Slider ( minimum = 1 , step = 1 , label = ' Batch count ' , value = 1 , elem_id = " img2img_batch_count " )
batch_size = gr . Slider ( minimum = 1 , maximum = 8 , step = 1 , label = ' Batch size ' , value = 1 , elem_id = " img2img_batch_size " )
2022-09-03 12:08:45 +03:00
2023-01-03 10:39:21 +03:00
elif category == " cfg " :
with FormGroup ( ) :
cfg_scale = gr . Slider ( minimum = 1.0 , maximum = 30.0 , step = 0.5 , label = ' CFG Scale ' , value = 7.0 , elem_id = " img2img_cfg_scale " )
denoising_strength = gr . Slider ( minimum = 0.0 , maximum = 1.0 , step = 0.01 , label = ' Denoising strength ' , value = 0.75 , elem_id = " img2img_denoising_strength " )
2022-09-03 12:08:45 +03:00
2023-01-03 10:39:21 +03:00
elif category == " seed " :
seed , reuse_seed , subseed , reuse_subseed , subseed_strength , seed_resize_from_h , seed_resize_from_w , seed_checkbox = create_seed_inputs ( ' img2img ' )
2022-09-03 12:08:45 +03:00
2023-01-03 10:39:21 +03:00
elif category == " checkboxes " :
with FormRow ( elem_id = " img2img_checkboxes " ) :
restore_faces = gr . Checkbox ( label = ' Restore faces ' , value = False , visible = len ( shared . face_restorers ) > 1 , elem_id = " img2img_restore_faces " )
tiling = gr . Checkbox ( label = ' Tiling ' , value = False , elem_id = " img2img_tiling " )
2022-09-03 12:08:45 +03:00
2023-01-03 10:39:21 +03:00
elif category == " batch " :
if not opts . dimensions_and_batch_together :
with FormRow ( elem_id = " img2img_column_batch " ) :
batch_count = gr . Slider ( minimum = 1 , step = 1 , label = ' Batch count ' , value = 1 , elem_id = " img2img_batch_count " )
batch_size = gr . Slider ( minimum = 1 , maximum = 8 , step = 1 , label = ' Batch size ' , value = 1 , elem_id = " img2img_batch_size " )
2022-09-03 17:21:15 +03:00
2023-01-03 10:39:21 +03:00
elif category == " scripts " :
with FormGroup ( elem_id = " img2img_script_container " ) :
custom_inputs = modules . scripts . scripts_img2img . setup_ui ( )
2022-09-17 12:38:15 +03:00
2023-01-15 23:32:38 +03:00
elif category == " inpaint " :
2023-01-15 02:26:45 +03:00
with FormGroup ( elem_id = " inpaint_controls " , visible = False ) as inpaint_controls :
with FormRow ( ) :
mask_blur = gr . Slider ( label = ' Mask blur ' , minimum = 0 , maximum = 64 , step = 1 , value = 4 , elem_id = " img2img_mask_blur " )
mask_alpha = gr . Slider ( label = " Mask transparency " , visible = False , elem_id = " img2img_mask_alpha " )
with FormRow ( ) :
inpainting_mask_invert = gr . Radio ( label = ' Mask mode ' , choices = [ ' Inpaint masked ' , ' Inpaint not masked ' ] , value = ' Inpaint masked ' , type = " index " , elem_id = " img2img_mask_mode " )
with FormRow ( ) :
inpainting_fill = gr . Radio ( label = ' Masked content ' , choices = [ ' fill ' , ' original ' , ' latent noise ' , ' latent nothing ' ] , value = ' original ' , type = " index " , elem_id = " img2img_inpainting_fill " )
with FormRow ( ) :
with gr . Column ( ) :
inpaint_full_res = gr . Radio ( label = " Inpaint area " , choices = [ " Whole picture " , " Only masked " ] , type = " index " , value = " Whole picture " , elem_id = " img2img_inpaint_full_res " )
with gr . Column ( scale = 4 ) :
inpaint_full_res_padding = gr . Slider ( label = ' Only masked padding, pixels ' , minimum = 0 , maximum = 256 , step = 4 , value = 32 , elem_id = " img2img_inpaint_full_res_padding " )
def select_img2img_tab(tab):
    """Return visibility updates for the inpaint controls and mask-alpha slider.

    ``tab`` is the position of the selected tab in the list enumerated right
    after this function (img2img=0, sketch=1, inpaint=2, inpaint sketch=3,
    inpaint upload=4, batch=5).

    Returns a 2-tuple of ``gr.update`` values, in the order of the handler's
    outputs: the inpaint control group is shown for tabs 2-4, the mask
    transparency slider only for tab 3.
    """
    show_inpaint_controls = tab in [2, 3, 4]
    show_mask_alpha = tab == 3
    return gr.update(visible=show_inpaint_controls), gr.update(visible=show_mask_alpha)
for i , elem in enumerate ( [ tab_img2img , tab_sketch , tab_inpaint , tab_inpaint_color , tab_inpaint_upload , tab_batch ] ) :
elem . select (
fn = lambda tab = i : select_img2img_tab ( tab ) ,
inputs = [ ] ,
outputs = [ inpaint_controls , mask_alpha ] ,
)
2022-12-31 23:40:55 +03:00
img2img_gallery , generation_info , html_info , html_log = create_output_panel ( " img2img " , opts . outdir_img2img_samples )
2022-10-29 10:56:19 +03:00
parameters_copypaste . bind_buttons ( { " img2img " : img2img_paste } , None , img2img_prompt )
2022-09-03 12:08:45 +03:00
2022-09-19 09:02:10 +03:00
connect_reuse_seed ( seed , reuse_seed , generation_info , dummy_component , is_subseed = False )
connect_reuse_seed ( subseed , reuse_subseed , generation_info , dummy_component , is_subseed = True )
2022-09-16 22:20:56 +03:00
2022-10-13 02:17:26 +03:00
img2img_prompt_img . change (
fn = modules . images . image_data ,
inputs = [
2022-10-14 18:15:03 +03:00
img2img_prompt_img
2022-10-13 02:17:26 +03:00
] ,
outputs = [
img2img_prompt ,
img2img_prompt_img
]
)
2022-09-03 12:08:45 +03:00
img2img_args = dict (
2022-12-31 23:40:55 +03:00
fn = wrap_gradio_gpu_call ( modules . img2img . img2img , extra_outputs = [ None , ' ' , ' ' ] ) ,
2022-09-22 12:11:48 +03:00
_js = " submit_img2img " ,
2022-09-03 12:08:45 +03:00
inputs = [
2023-01-15 18:50:56 +03:00
dummy_component ,
2022-09-22 12:11:48 +03:00
dummy_component ,
2022-09-09 23:16:02 +03:00
img2img_prompt ,
2022-09-11 17:35:12 +03:00
img2img_negative_prompt ,
2023-01-14 14:56:39 +03:00
img2img_prompt_styles ,
2022-09-03 12:08:45 +03:00
init_img ,
2023-01-11 20:33:24 +03:00
sketch ,
2022-09-03 12:08:45 +03:00
init_img_with_mask ,
2023-01-11 20:33:24 +03:00
inpaint_color_sketch ,
inpaint_color_sketch_orig ,
2022-09-22 12:11:48 +03:00
init_img_inpaint ,
init_mask_inpaint ,
2022-09-03 12:08:45 +03:00
steps ,
sampler_index ,
mask_blur ,
2022-11-09 06:06:29 +03:00
mask_alpha ,
2022-09-03 12:08:45 +03:00
inpainting_fill ,
2022-09-07 12:32:28 +03:00
restore_faces ,
2022-09-05 03:25:37 +03:00
tiling ,
2022-09-03 12:08:45 +03:00
batch_count ,
batch_size ,
cfg_scale ,
denoising_strength ,
seed ,
2022-09-21 13:34:10 +03:00
subseed , subseed_strength , seed_resize_from_h , seed_resize_from_w , seed_checkbox ,
2022-09-03 12:08:45 +03:00
height ,
width ,
resize_mode ,
inpaint_full_res ,
2022-09-22 12:11:48 +03:00
inpaint_full_res_padding ,
2022-09-03 21:02:38 +03:00
inpainting_mask_invert ,
2022-09-22 12:11:48 +03:00
img2img_batch_input_dir ,
img2img_batch_output_dir ,
2022-09-03 17:21:15 +03:00
] + custom_inputs ,
2022-09-03 12:08:45 +03:00
outputs = [
img2img_gallery ,
generation_info ,
2022-12-31 23:40:55 +03:00
html_info ,
html_log ,
2022-09-18 11:14:42 +03:00
] ,
show_progress = False ,
2022-09-03 12:08:45 +03:00
)
2023-01-18 20:16:52 +03:00
interrogate_args = dict (
_js = " get_img2img_tab_index " ,
inputs = [
dummy_component ,
img2img_batch_input_dir ,
img2img_batch_output_dir ,
init_img ,
sketch ,
init_img_with_mask ,
inpaint_color_sketch ,
init_img_inpaint ,
] ,
outputs = [ img2img_prompt , dummy_component ] ,
)
2022-09-09 23:16:02 +03:00
img2img_prompt . submit ( * * img2img_args )
2022-09-03 12:08:45 +03:00
submit . click ( * * img2img_args )
2022-09-11 18:48:36 +03:00
img2img_interrogate . click (
2023-01-21 09:14:27 +03:00
fn = lambda * args : process_interrogate ( interrogate , * args ) ,
2023-01-18 20:16:52 +03:00
* * interrogate_args ,
2022-09-11 18:48:36 +03:00
)
2022-11-26 16:10:46 +03:00
img2img_deepbooru . click (
2023-01-21 09:14:27 +03:00
fn = lambda * args : process_interrogate ( interrogate_deepbooru , * args ) ,
2023-01-18 20:16:52 +03:00
* * interrogate_args ,
2022-09-14 17:56:21 +03:00
)
prompts = [ ( txt2img_prompt , txt2img_negative_prompt ) , ( img2img_prompt , img2img_negative_prompt ) ]
2023-01-14 14:56:39 +03:00
style_dropdowns = [ txt2img_prompt_styles , img2img_prompt_styles ]
2022-09-30 19:12:44 +03:00
style_js_funcs = [ " update_txt2img_tokens " , " update_img2img_tokens " ]
2022-09-14 17:56:21 +03:00
for button , ( prompt , negative_prompt ) in zip ( [ txt2img_save_style , img2img_save_style ] , prompts ) :
2022-09-09 23:16:02 +03:00
button . click (
fn = add_style ,
_js = " ask_for_style_name " ,
2022-09-11 17:35:12 +03:00
# Have to pass empty dummy component here, because the JavaScript and Python function have to accept
# the same number of parameters, but we only know the style-name after the JavaScript prompt
inputs = [ dummy_component , prompt , negative_prompt ] ,
2023-01-14 14:56:39 +03:00
outputs = [ txt2img_prompt_styles , img2img_prompt_styles ] ,
2022-09-14 17:56:21 +03:00
)
2023-01-14 14:56:39 +03:00
for button , ( prompt , negative_prompt ) , styles , js_func in zip ( [ txt2img_prompt_style_apply , img2img_prompt_style_apply ] , prompts , style_dropdowns , style_js_funcs ) :
2022-09-14 17:56:21 +03:00
button . click (
fn = apply_styles ,
2022-09-29 21:40:47 +03:00
_js = js_func ,
2023-01-14 14:56:39 +03:00
inputs = [ prompt , negative_prompt , styles ] ,
outputs = [ prompt , negative_prompt , styles ] ,
2022-09-09 23:16:02 +03:00
)
2022-10-27 08:36:11 +03:00
token_button . click ( fn = update_token_counter , inputs = [ img2img_prompt , steps ] , outputs = [ token_counter ] )
2023-01-20 10:18:41 +03:00
# img2img tab: count tokens of the img2img negative prompt (this handler previously read the txt2img negative prompt).
negative_token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[img2img_negative_prompt, steps], outputs=[negative_token_counter])
2022-10-27 08:36:11 +03:00
2023-01-21 08:36:07 +03:00
ui_extra_networks . setup_ui ( extra_networks_ui_img2img , img2img_gallery )
2022-09-25 09:25:28 +03:00
img2img_paste_fields = [
( img2img_prompt , " Prompt " ) ,
( img2img_negative_prompt , " Negative prompt " ) ,
( steps , " Steps " ) ,
( sampler_index , " Sampler " ) ,
( restore_faces , " Face restoration " ) ,
( cfg_scale , " CFG scale " ) ,
( seed , " Seed " ) ,
( width , " Size-1 " ) ,
( height , " Size-2 " ) ,
( batch_size , " Batch size " ) ,
( subseed , " Variation seed " ) ,
( subseed_strength , " Variation seed strength " ) ,
( seed_resize_from_w , " Seed resize from-1 " ) ,
( seed_resize_from_h , " Seed resize from-2 " ) ,
( denoising_strength , " Denoising strength " ) ,
2022-11-27 16:35:35 +03:00
( mask_blur , " Mask blur " ) ,
2022-10-22 12:23:45 +03:00
* modules . scripts . scripts_img2img . infotext_fields
2022-09-25 09:25:28 +03:00
]
2022-10-29 08:28:48 +03:00
parameters_copypaste . add_paste_fields ( " img2img " , init_img , img2img_paste_fields )
2022-10-27 08:36:11 +03:00
parameters_copypaste . add_paste_fields ( " inpaint " , init_img_with_mask , img2img_paste_fields )
2022-09-23 22:49:21 +03:00
2022-11-19 19:10:17 +03:00
modules . scripts . scripts_current = None
2022-09-23 22:49:21 +03:00
2022-09-03 12:08:45 +03:00
with gr . Blocks ( analytics_enabled = False ) as extras_interface :
2023-01-23 09:24:43 +03:00
ui_postprocessing . create_ui ( )
2022-09-03 12:08:45 +03:00
2022-09-23 22:49:21 +03:00
with gr . Blocks ( analytics_enabled = False ) as pnginfo_interface :
with gr . Row ( ) . style ( equal_height = False ) :
with gr . Column ( variant = ' panel ' ) :
image = gr . Image ( elem_id = " pnginfo_image " , label = " Source " , source = " upload " , interactive = True , type = " pil " )
with gr . Column ( variant = ' panel ' ) :
html = gr . HTML ( )
2023-01-01 16:51:12 +03:00
generation_info = gr . Textbox ( visible = False , elem_id = " pnginfo_generation_info " )
2022-09-23 22:49:21 +03:00
html2 = gr . HTML ( )
with gr . Row ( ) :
2022-10-27 08:36:11 +03:00
buttons = parameters_copypaste . create_buttons ( [ " txt2img " , " img2img " , " inpaint " , " extras " ] )
2022-10-29 08:28:48 +03:00
parameters_copypaste . bind_buttons ( buttons , image , generation_info )
2022-09-23 22:49:21 +03:00
image . change (
2022-10-02 15:03:39 +03:00
fn = wrap_gradio_call ( modules . extras . run_pnginfo ) ,
2022-09-23 22:49:21 +03:00
inputs = [ image ] ,
outputs = [ html , generation_info , html2 ] ,
)
2022-10-29 08:28:48 +03:00
2023-01-20 08:48:15 +03:00
# Return the HTML help paragraph for the selected checkpoint-merger
# interpolation method.
#
# value: one of the keys of `interp_descriptions` below; these are the same
#        labels offered by the "Interpolation Method" radio further down.
# Returns the description text wrapped in the paragraph template.
# Raises KeyError if `value` is not one of the known method names.
def update_interp_description ( value ) :
# Paragraph template; `{}` is filled with the per-method description.
interp_description_css = " <p style= ' margin-bottom: 2.5em ' > {} </p> "
interp_descriptions = {
" No interpolation " : interp_description_css . format ( " No interpolation will be used. Requires one model; A. Allows for format conversion and VAE baking. " ) ,
" Weighted sum " : interp_description_css . format ( " A weighted sum will be used for interpolation. Requires two models; A and B. The result is calculated as A * (1 - M) + B * M " ) ,
" Add difference " : interp_description_css . format ( " The difference between the last two models will be added to the first. Requires three models; A, B and C. The result is calculated as A + (B - C) * M " )
}
return interp_descriptions [ value ]
2022-11-06 14:39:41 +03:00
with gr . Blocks ( analytics_enabled = False ) as modelmerger_interface :
2022-09-26 02:22:12 +03:00
with gr . Row ( ) . style ( equal_height = False ) :
2023-01-14 13:38:10 +03:00
with gr . Column ( variant = ' compact ' ) :
2023-01-20 08:48:15 +03:00
interp_description = gr . HTML ( value = update_interp_description ( " Weighted sum " ) , elem_id = " modelmerger_interp_description " )
2022-10-02 15:03:39 +03:00
2023-01-19 10:39:51 +03:00
with FormRow ( elem_id = " modelmerger_models " ) :
2022-10-14 09:05:06 +03:00
primary_model_name = gr . Dropdown ( modules . sd_models . checkpoint_tiles ( ) , elem_id = " modelmerger_primary_model_name " , label = " Primary model (A) " )
2023-01-01 10:35:38 +03:00
create_refresh_button ( primary_model_name , modules . sd_models . list_models , lambda : { " choices " : modules . sd_models . checkpoint_tiles ( ) } , " refresh_checkpoint_A " )
2022-10-14 09:05:06 +03:00
secondary_model_name = gr . Dropdown ( modules . sd_models . checkpoint_tiles ( ) , elem_id = " modelmerger_secondary_model_name " , label = " Secondary model (B) " )
2023-01-01 10:35:38 +03:00
create_refresh_button ( secondary_model_name , modules . sd_models . list_models , lambda : { " choices " : modules . sd_models . checkpoint_tiles ( ) } , " refresh_checkpoint_B " )
2022-10-14 09:05:06 +03:00
tertiary_model_name = gr . Dropdown ( modules . sd_models . checkpoint_tiles ( ) , elem_id = " modelmerger_tertiary_model_name " , label = " Tertiary model (C) " )
2023-01-01 10:35:38 +03:00
create_refresh_button ( tertiary_model_name , modules . sd_models . list_models , lambda : { " choices " : modules . sd_models . checkpoint_tiles ( ) } , " refresh_checkpoint_C " )
2023-01-01 16:51:12 +03:00
custom_name = gr . Textbox ( label = " Custom Name (Optional) " , elem_id = " modelmerger_custom_name " )
interp_amount = gr . Slider ( minimum = 0.0 , maximum = 1.0 , step = 0.05 , label = ' Multiplier (M) - set to 0 to get model A ' , value = 0.3 , elem_id = " modelmerger_interp_amount " )
2023-01-19 10:39:51 +03:00
interp_method = gr . Radio ( choices = [ " No interpolation " , " Weighted sum " , " Add difference " ] , value = " Weighted sum " , label = " Interpolation Method " , elem_id = " modelmerger_interp_method " )
2023-01-20 08:48:15 +03:00
interp_method . change ( fn = update_interp_description , inputs = [ interp_method ] , outputs = [ interp_description ] )
2022-11-27 15:51:29 +03:00
2023-01-11 09:10:07 +03:00
with FormRow ( ) :
2023-01-01 16:51:12 +03:00
checkpoint_format = gr . Radio ( choices = [ " ckpt " , " safetensors " ] , value = " ckpt " , label = " Checkpoint format " , elem_id = " modelmerger_checkpoint_format " )
save_as_half = gr . Checkbox ( value = False , label = " Save as float16 " , elem_id = " modelmerger_save_as_half " )
2022-11-27 15:51:29 +03:00
2023-01-19 10:39:51 +03:00
with FormRow ( ) :
with gr . Column ( ) :
config_source = gr . Radio ( choices = [ " A, B or C " , " B " , " C " , " Don ' t " ] , value = " A, B or C " , label = " Copy config from " , type = " index " , elem_id = " modelmerger_config_method " )
with gr . Column ( ) :
with FormRow ( ) :
bake_in_vae = gr . Dropdown ( choices = [ " None " ] + list ( sd_vae . vae_dict ) , value = " None " , label = " Bake in VAE " , elem_id = " modelmerger_bake_in_vae " )
create_refresh_button ( bake_in_vae , sd_vae . refresh_vae_list , lambda : { " choices " : [ " None " ] + list ( sd_vae . vae_dict ) } , " modelmerger_refresh_bake_in_vae " )
2023-01-11 09:10:07 +03:00
2023-01-22 10:17:12 +03:00
with FormRow ( ) :
discard_weights = gr . Textbox ( value = " " , label = " Discard weights with matching name " , elem_id = " modelmerger_discard_weights " )
2023-01-14 13:38:10 +03:00
with gr . Row ( ) :
modelmerger_merge = gr . Button ( elem_id = " modelmerger_merge " , value = " Merge " , variant = ' primary ' )
2022-10-02 15:03:39 +03:00
2023-01-19 09:25:37 +03:00
with gr . Column ( variant = ' compact ' , elem_id = " modelmerger_results_container " ) :
with gr . Group ( elem_id = " modelmerger_results_panel " ) :
modelmerger_result = gr . HTML ( elem_id = " modelmerger_result " , show_label = False )
2022-09-26 02:22:12 +03:00
2022-11-06 14:39:41 +03:00
with gr . Blocks ( analytics_enabled = False ) as train_interface :
2022-10-02 15:03:39 +03:00
with gr . Row ( ) . style ( equal_height = False ) :
2022-10-12 11:05:57 +03:00
gr . HTML ( value = " <p style= ' margin-bottom: 0.7em ' >See <b><a href= \" https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Textual-Inversion \" >wiki</a></b> for detailed explanation.</p> " )
2022-10-02 22:41:21 +03:00
2023-01-21 23:40:13 +03:00
with gr . Row ( variant = " compact " ) . style ( equal_height = False ) :
2022-10-12 11:05:57 +03:00
with gr . Tabs ( elem_id = " train_tabs " ) :
2022-10-02 15:03:39 +03:00
2022-10-12 11:05:57 +03:00
with gr . Tab ( label = " Create embedding " ) :
2023-01-01 16:51:12 +03:00
new_embedding_name = gr . Textbox ( label = " Name " , elem_id = " train_new_embedding_name " )
initialization_text = gr . Textbox ( label = " Initialization text " , value = " * " , elem_id = " train_initialization_text " )
nvpt = gr . Slider ( label = " Number of vectors per token " , minimum = 1 , maximum = 75 , step = 1 , value = 1 , elem_id = " train_nvpt " )
overwrite_old_embedding = gr . Checkbox ( value = False , label = " Overwrite Old Embedding " , elem_id = " train_overwrite_old_embedding " )
2022-10-02 15:03:39 +03:00
with gr . Row ( ) :
with gr . Column ( scale = 3 ) :
gr . HTML ( value = " " )
with gr . Column ( ) :
2023-01-01 16:51:12 +03:00
create_embedding = gr . Button ( value = " Create embedding " , variant = ' primary ' , elem_id = " train_create_embedding " )
2022-10-02 15:03:39 +03:00
2022-10-12 11:05:57 +03:00
with gr . Tab ( label = " Create hypernetwork " ) :
2023-01-01 16:51:12 +03:00
new_hypernetwork_name = gr . Textbox ( label = " Name " , elem_id = " train_new_hypernetwork_name " )
new_hypernetwork_sizes = gr . CheckboxGroup ( label = " Modules " , value = [ " 768 " , " 320 " , " 640 " , " 1280 " ] , choices = [ " 768 " , " 1024 " , " 320 " , " 640 " , " 1280 " ] , elem_id = " train_new_hypernetwork_sizes " )
new_hypernetwork_layer_structure = gr . Textbox ( " 1, 2, 1 " , label = " Enter hypernetwork layer structure " , placeholder = " 1st and last digit must be 1. ex: ' 1, 2, 1 ' " , elem_id = " train_new_hypernetwork_layer_structure " )
new_hypernetwork_activation_func = gr . Dropdown ( value = " linear " , label = " Select activation function of hypernetwork. Recommended : Swish / Linear(none) " , choices = modules . hypernetworks . ui . keys , elem_id = " train_new_hypernetwork_activation_func " )
new_hypernetwork_initialization_option = gr . Dropdown ( value = " Normal " , label = " Select Layer weights initialization. Recommended: Kaiming for relu-like, Xavier for sigmoid-like, Normal otherwise " , choices = [ " Normal " , " KaimingUniform " , " KaimingNormal " , " XavierUniform " , " XavierNormal " ] , elem_id = " train_new_hypernetwork_initialization_option " )
new_hypernetwork_add_layer_norm = gr . Checkbox ( label = " Add layer normalization " , elem_id = " train_new_hypernetwork_add_layer_norm " )
new_hypernetwork_use_dropout = gr . Checkbox ( label = " Use dropout " , elem_id = " train_new_hypernetwork_use_dropout " )
2023-01-10 08:56:57 +03:00
new_hypernetwork_dropout_structure = gr . Textbox ( " 0, 0, 0 " , label = " Enter hypernetwork Dropout structure (or empty). Recommended : 0~0.35 incrementing sequence: 0, 0.05, 0.15 " , placeholder = " 1st and last digit must be 0 and values should be between 0 and 1. ex: ' 0, 0.01, 0 ' " )
2023-01-01 16:51:12 +03:00
overwrite_old_hypernetwork = gr . Checkbox ( value = False , label = " Overwrite Old Hypernetwork " , elem_id = " train_overwrite_old_hypernetwork " )
2022-10-07 23:22:22 +03:00
with gr . Row ( ) :
with gr . Column ( scale = 3 ) :
gr . HTML ( value = " " )
with gr . Column ( ) :
2023-01-01 16:51:12 +03:00
create_hypernetwork = gr . Button ( value = " Create hypernetwork " , variant = ' primary ' , elem_id = " train_create_hypernetwork " )
2022-10-02 15:03:39 +03:00
2022-10-12 11:05:57 +03:00
with gr . Tab ( label = " Preprocess images " ) :
2023-01-01 16:51:12 +03:00
process_src = gr . Textbox ( label = ' Source directory ' , elem_id = " train_process_src " )
process_dst = gr . Textbox ( label = ' Destination directory ' , elem_id = " train_process_dst " )
process_width = gr . Slider ( minimum = 64 , maximum = 2048 , step = 8 , label = " Width " , value = 512 , elem_id = " train_process_width " )
process_height = gr . Slider ( minimum = 64 , maximum = 2048 , step = 8 , label = " Height " , value = 512 , elem_id = " train_process_height " )
preprocess_txt_action = gr . Dropdown ( label = ' Existing Caption txt Action ' , value = " ignore " , choices = [ " ignore " , " copy " , " prepend " , " append " ] , elem_id = " train_preprocess_txt_action " )
2022-10-02 22:41:21 +03:00
with gr . Row ( ) :
2023-01-01 16:51:12 +03:00
process_flip = gr . Checkbox ( label = ' Create flipped copies ' , elem_id = " train_process_flip " )
process_split = gr . Checkbox ( label = ' Split oversized images ' , elem_id = " train_process_split " )
process_focal_crop = gr . Checkbox ( label = ' Auto focal point crop ' , elem_id = " train_process_focal_crop " )
2023-01-17 12:16:43 +03:00
process_multicrop = gr . Checkbox ( label = ' Auto-sized crop ' , elem_id = " train_process_multicrop " )
2023-01-01 16:51:12 +03:00
process_caption = gr . Checkbox ( label = ' Use BLIP for caption ' , elem_id = " train_process_caption " )
process_caption_deepbooru = gr . Checkbox ( label = ' Use deepbooru for caption ' , visible = True , elem_id = " train_process_caption_deepbooru " )
2022-10-02 22:41:21 +03:00
2022-10-20 16:56:45 +03:00
with gr . Row ( visible = False ) as process_split_extra_row :
2023-01-01 16:51:12 +03:00
process_split_threshold = gr . Slider ( label = ' Split image threshold ' , value = 0.5 , minimum = 0.0 , maximum = 1.0 , step = 0.05 , elem_id = " train_process_split_threshold " )
process_overlap_ratio = gr . Slider ( label = ' Split image overlap ratio ' , value = 0.2 , minimum = 0.0 , maximum = 0.9 , step = 0.05 , elem_id = " train_process_overlap_ratio " )
2022-10-20 16:56:45 +03:00
2022-10-26 01:22:29 +03:00
with gr . Row ( visible = False ) as process_focal_crop_row :
2023-01-01 16:51:12 +03:00
process_focal_crop_face_weight = gr . Slider ( label = ' Focal point face weight ' , value = 0.9 , minimum = 0.0 , maximum = 1.0 , step = 0.05 , elem_id = " train_process_focal_crop_face_weight " )
process_focal_crop_entropy_weight = gr . Slider ( label = ' Focal point entropy weight ' , value = 0.15 , minimum = 0.0 , maximum = 1.0 , step = 0.05 , elem_id = " train_process_focal_crop_entropy_weight " )
process_focal_crop_edges_weight = gr . Slider ( label = ' Focal point edges weight ' , value = 0.5 , minimum = 0.0 , maximum = 1.0 , step = 0.05 , elem_id = " train_process_focal_crop_edges_weight " )
process_focal_crop_debug = gr . Checkbox ( label = ' Create debug image ' , elem_id = " train_process_focal_crop_debug " )
2023-01-17 12:16:43 +03:00
with gr . Column ( visible = False ) as process_multicrop_col :
gr . Markdown ( ' Each image is center-cropped with an automatically chosen width and height. ' )
with gr . Row ( ) :
process_multicrop_mindim = gr . Slider ( minimum = 64 , maximum = 2048 , step = 8 , label = " Dimension lower bound " , value = 384 , elem_id = " train_process_multicrop_mindim " )
process_multicrop_maxdim = gr . Slider ( minimum = 64 , maximum = 2048 , step = 8 , label = " Dimension upper bound " , value = 768 , elem_id = " train_process_multicrop_maxdim " )
with gr . Row ( ) :
process_multicrop_minarea = gr . Slider ( minimum = 64 * 64 , maximum = 2048 * 2048 , step = 1 , label = " Area lower bound " , value = 64 * 64 , elem_id = " train_process_multicrop_minarea " )
process_multicrop_maxarea = gr . Slider ( minimum = 64 * 64 , maximum = 2048 * 2048 , step = 1 , label = " Area upper bound " , value = 640 * 640 , elem_id = " train_process_multicrop_maxarea " )
with gr . Row ( ) :
process_multicrop_objective = gr . Radio ( [ " Maximize area " , " Minimize error " ] , value = " Maximize area " , label = " Resizing objective " , elem_id = " train_process_multicrop_objective " )
process_multicrop_threshold = gr . Slider ( minimum = 0 , maximum = 1 , step = 0.01 , label = " Error threshold " , value = 0.1 , elem_id = " train_process_multicrop_threshold " )
2022-10-02 22:41:21 +03:00
with gr . Row ( ) :
with gr . Column ( scale = 3 ) :
gr . HTML ( value = " " )
with gr . Column ( ) :
2022-11-18 05:03:57 +03:00
with gr . Row ( ) :
2023-01-01 16:51:12 +03:00
interrupt_preprocessing = gr . Button ( " Interrupt " , elem_id = " train_interrupt_preprocessing " )
run_preprocess = gr . Button ( value = " Preprocess " , variant = ' primary ' , elem_id = " train_run_preprocess " )
2022-10-02 22:41:21 +03:00
2022-10-20 16:56:45 +03:00
process_split . change (
fn = lambda show : gr_show ( show ) ,
inputs = [ process_split ] ,
outputs = [ process_split_extra_row ] ,
)
2022-10-26 01:22:29 +03:00
process_focal_crop . change (
fn = lambda show : gr_show ( show ) ,
inputs = [ process_focal_crop ] ,
outputs = [ process_focal_crop_row ] ,
)
2023-01-17 12:16:43 +03:00
process_multicrop . change (
fn = lambda show : gr_show ( show ) ,
inputs = [ process_multicrop ] ,
outputs = [ process_multicrop_col ] ,
)
2023-01-09 23:35:40 +03:00
def get_textual_inversion_template_names():
    """Return a sorted list of the entries of the textual-inversion template registry.

    Iterates ``textual_inversion.textual_inversion_templates`` and returns the
    items it yields in ascending order. ``sorted`` already accepts any iterable
    and always returns a new list, so no intermediate list copy is needed.
    """
    return sorted(textual_inversion.textual_inversion_templates)
2022-10-12 11:05:57 +03:00
with gr . Tab ( label = " Train " ) :
2022-10-19 22:33:18 +03:00
gr . HTML ( value = " <p style= ' margin-bottom: 0.7em ' >Train an embedding or Hypernetwork; you must specify a directory with a set of 1:1 ratio images <a href= \" https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Textual-Inversion \" style= \" font-weight:bold; \" >[wiki]</a></p> " )
2023-01-04 20:10:40 +03:00
with FormRow ( ) :
2022-10-17 21:15:32 +03:00
train_embedding_name = gr . Dropdown ( label = ' Embedding ' , elem_id = " train_embedding " , choices = sorted ( sd_hijack . model_hijack . embedding_db . word_embeddings . keys ( ) ) )
2022-10-16 07:42:52 +03:00
create_refresh_button ( train_embedding_name , sd_hijack . model_hijack . embedding_db . load_textual_inversion_embeddings , lambda : { " choices " : sorted ( sd_hijack . model_hijack . embedding_db . word_embeddings . keys ( ) ) } , " refresh_train_embedding_name " )
2023-01-04 20:10:40 +03:00
2022-10-17 21:15:32 +03:00
train_hypernetwork_name = gr . Dropdown ( label = ' Hypernetwork ' , elem_id = " train_hypernetwork " , choices = [ x for x in shared . hypernetworks . keys ( ) ] )
2022-10-16 07:42:52 +03:00
create_refresh_button ( train_hypernetwork_name , shared . reload_hypernetworks , lambda : { " choices " : sorted ( [ x for x in shared . hypernetworks . keys ( ) ] ) } , " refresh_train_hypernetwork_name " )
2023-01-04 20:10:40 +03:00
with FormRow ( ) :
2023-01-01 16:51:12 +03:00
embedding_learn_rate = gr . Textbox ( label = ' Embedding Learning rate ' , placeholder = " Embedding Learning rate " , value = " 0.005 " , elem_id = " train_embedding_learn_rate " )
hypernetwork_learn_rate = gr . Textbox ( label = ' Hypernetwork Learning rate ' , placeholder = " Hypernetwork Learning rate " , value = " 0.00001 " , elem_id = " train_hypernetwork_learn_rate " )
2023-01-04 19:56:35 +03:00
2023-01-04 20:10:40 +03:00
with FormRow ( ) :
2022-10-28 13:16:23 +03:00
clip_grad_mode = gr . Dropdown ( value = " disabled " , label = " Gradient Clipping " , choices = [ " disabled " , " value " , " norm " ] )
2022-10-31 09:49:24 +03:00
clip_grad_value = gr . Textbox ( placeholder = " Gradient clip value " , value = " 0.1 " , show_label = False )
2023-01-01 16:51:12 +03:00
2023-01-04 20:10:40 +03:00
with FormRow ( ) :
batch_size = gr . Number ( label = ' Batch size ' , value = 1 , precision = 0 , elem_id = " train_batch_size " )
gradient_step = gr . Number ( label = ' Gradient accumulation steps ' , value = 1 , precision = 0 , elem_id = " train_gradient_step " )
2023-01-01 16:51:12 +03:00
dataset_directory = gr . Textbox ( label = ' Dataset directory ' , placeholder = " Path to directory with input images " , elem_id = " train_dataset_directory " )
log_directory = gr . Textbox ( label = ' Log directory ' , placeholder = " Path to directory where to write outputs " , value = " textual_inversion " , elem_id = " train_log_directory " )
2023-01-09 23:35:40 +03:00
with FormRow ( ) :
template_file = gr . Dropdown ( label = ' Prompt template ' , value = " style_filewords.txt " , elem_id = " train_template_file " , choices = get_textual_inversion_template_names ( ) )
create_refresh_button ( template_file , textual_inversion . list_textual_inversion_templates , lambda : { " choices " : get_textual_inversion_template_names ( ) } , " refrsh_train_template_file " )
2023-01-01 16:51:12 +03:00
training_width = gr . Slider ( minimum = 64 , maximum = 2048 , step = 8 , label = " Width " , value = 512 , elem_id = " train_training_width " )
training_height = gr . Slider ( minimum = 64 , maximum = 2048 , step = 8 , label = " Height " , value = 512 , elem_id = " train_training_height " )
2023-01-09 22:52:23 +03:00
varsize = gr . Checkbox ( label = " Do not resize images " , value = False , elem_id = " train_varsize " )
2023-01-01 16:51:12 +03:00
steps = gr . Number ( label = ' Max steps ' , value = 100000 , precision = 0 , elem_id = " train_steps " )
2023-01-04 20:10:40 +03:00
with FormRow ( ) :
create_image_every = gr . Number ( label = ' Save an image to log directory every N steps, 0 to disable ' , value = 500 , precision = 0 , elem_id = " train_create_image_every " )
save_embedding_every = gr . Number ( label = ' Save a copy of embedding to log directory every N steps, 0 to disable ' , value = 500 , precision = 0 , elem_id = " train_save_embedding_every " )
2023-01-01 16:51:12 +03:00
save_image_with_stored_embedding = gr . Checkbox ( label = ' Save images with embedding in PNG chunks ' , value = True , elem_id = " train_save_image_with_stored_embedding " )
preview_from_txt2img = gr . Checkbox ( label = ' Read parameters (prompt, etc...) from txt2img tab when making previews ' , value = False , elem_id = " train_preview_from_txt2img " )
2023-01-04 20:10:40 +03:00
shuffle_tags = gr . Checkbox ( label = " Shuffle tags by ' , ' when creating prompts. " , value = False , elem_id = " train_shuffle_tags " )
tag_drop_out = gr . Slider ( minimum = 0 , maximum = 1 , step = 0.1 , label = " Drop out tags when creating prompts. " , value = 0 , elem_id = " train_tag_drop_out " )
latent_sampling_method = gr . Radio ( label = ' Choose latent sampling method ' , value = " once " , choices = [ ' once ' , ' deterministic ' , ' random ' ] , elem_id = " train_latent_sampling_method " )
2022-10-02 15:03:39 +03:00
with gr . Row ( ) :
2023-01-04 20:10:40 +03:00
train_embedding = gr . Button ( value = " Train Embedding " , variant = ' primary ' , elem_id = " train_train_embedding " )
2023-01-01 16:51:12 +03:00
interrupt_training = gr . Button ( value = " Interrupt " , elem_id = " train_interrupt_training " )
train_hypernetwork = gr . Button ( value = " Train Hypernetwork " , variant = ' primary ' , elem_id = " train_train_hypernetwork " )
2022-10-02 15:03:39 +03:00
2022-11-08 08:38:10 +03:00
params = script_callbacks . UiTrainTabParams ( txt2img_preview_params )
script_callbacks . ui_train_tabs_callback ( params )
2023-01-15 18:50:56 +03:00
with gr . Column ( elem_id = ' ti_gallery_container ' ) :
2022-10-02 15:03:39 +03:00
ti_output = gr . Text ( elem_id = " ti_output " , value = " " , show_label = False )
ti_gallery = gr . Gallery ( label = ' Output ' , show_label = False , elem_id = ' ti_gallery ' ) . style ( grid = 4 )
ti_progress = gr . HTML ( elem_id = " ti_progress " , value = " " )
ti_outcome = gr . HTML ( elem_id = " ti_error " , value = " " )
create_embedding . click (
fn = modules . textual_inversion . ui . create_embedding ,
inputs = [
new_embedding_name ,
2022-10-02 19:40:51 +03:00
initialization_text ,
2022-10-02 15:03:39 +03:00
nvpt ,
2022-10-19 22:33:18 +03:00
overwrite_old_embedding ,
2022-10-02 15:03:39 +03:00
] ,
outputs = [
train_embedding_name ,
ti_output ,
ti_outcome ,
]
)
2022-10-07 23:22:22 +03:00
create_hypernetwork . click (
2022-10-11 15:54:34 +03:00
fn = modules . hypernetworks . ui . create_hypernetwork ,
2022-10-07 23:22:22 +03:00
inputs = [
new_hypernetwork_name ,
2022-10-11 18:04:47 +03:00
new_hypernetwork_sizes ,
2022-10-20 02:27:16 +03:00
overwrite_old_hypernetwork ,
2022-10-19 17:30:33 +03:00
new_hypernetwork_layer_structure ,
2022-10-20 03:10:45 +03:00
new_hypernetwork_activation_func ,
2022-10-25 08:48:49 +03:00
new_hypernetwork_initialization_option ,
2022-10-19 17:30:33 +03:00
new_hypernetwork_add_layer_norm ,
2023-01-10 08:56:57 +03:00
new_hypernetwork_use_dropout ,
new_hypernetwork_dropout_structure
2022-10-07 23:22:22 +03:00
] ,
outputs = [
train_hypernetwork_name ,
ti_output ,
ti_outcome ,
]
)
2022-10-02 22:41:21 +03:00
run_preprocess . click (
fn = wrap_gradio_gpu_call ( modules . textual_inversion . ui . preprocess , extra_outputs = [ gr . update ( ) ] ) ,
_js = " start_training_textual_inversion " ,
inputs = [
2023-01-15 18:50:56 +03:00
dummy_component ,
2022-10-02 22:41:21 +03:00
process_src ,
process_dst ,
2022-10-10 16:35:35 +03:00
process_width ,
process_height ,
2022-10-20 02:48:07 +03:00
preprocess_txt_action ,
2022-10-02 22:41:21 +03:00
process_flip ,
process_split ,
process_caption ,
2022-10-20 16:56:45 +03:00
process_caption_deepbooru ,
process_split_threshold ,
process_overlap_ratio ,
2022-10-26 01:22:29 +03:00
process_focal_crop ,
process_focal_crop_face_weight ,
process_focal_crop_entropy_weight ,
process_focal_crop_edges_weight ,
process_focal_crop_debug ,
2023-01-17 12:16:43 +03:00
process_multicrop ,
process_multicrop_mindim ,
process_multicrop_maxdim ,
process_multicrop_minarea ,
process_multicrop_maxarea ,
process_multicrop_objective ,
process_multicrop_threshold ,
2022-10-02 22:41:21 +03:00
] ,
outputs = [
ti_output ,
ti_outcome ,
] ,
)
2022-10-02 15:03:39 +03:00
train_embedding . click (
fn = wrap_gradio_gpu_call ( modules . textual_inversion . ui . train_embedding , extra_outputs = [ gr . update ( ) ] ) ,
_js = " start_training_textual_inversion " ,
inputs = [
2023-01-15 18:50:56 +03:00
dummy_component ,
2022-10-02 15:03:39 +03:00
train_embedding_name ,
2022-10-20 02:19:40 +03:00
embedding_learn_rate ,
2022-10-15 09:24:59 +03:00
batch_size ,
2022-11-20 06:35:26 +03:00
gradient_step ,
2022-10-02 15:03:39 +03:00
dataset_directory ,
log_directory ,
2022-10-10 16:35:35 +03:00
training_width ,
training_height ,
2023-01-07 20:34:52 +03:00
varsize ,
2022-10-02 15:03:39 +03:00
steps ,
2022-10-28 06:31:27 +03:00
clip_grad_mode ,
clip_grad_value ,
2022-11-20 06:35:26 +03:00
shuffle_tags ,
tag_drop_out ,
latent_sampling_method ,
2022-10-02 15:03:39 +03:00
create_image_every ,
save_embedding_every ,
template_file ,
2022-10-09 07:40:57 +03:00
save_image_with_stored_embedding ,
2022-10-14 20:31:49 +03:00
preview_from_txt2img ,
* txt2img_preview_params ,
2022-10-02 15:03:39 +03:00
] ,
outputs = [
ti_output ,
ti_outcome ,
]
)
2022-10-07 23:22:22 +03:00
train_hypernetwork . click (
2022-10-11 15:54:34 +03:00
fn = wrap_gradio_gpu_call ( modules . hypernetworks . ui . train_hypernetwork , extra_outputs = [ gr . update ( ) ] ) ,
2022-10-07 23:22:22 +03:00
_js = " start_training_textual_inversion " ,
inputs = [
2023-01-15 18:50:56 +03:00
dummy_component ,
2022-10-07 23:22:22 +03:00
train_hypernetwork_name ,
2022-10-20 02:19:40 +03:00
hypernetwork_learn_rate ,
2022-10-15 09:24:59 +03:00
batch_size ,
2022-11-20 06:35:26 +03:00
gradient_step ,
2022-10-07 23:22:22 +03:00
dataset_directory ,
log_directory ,
2022-10-19 08:44:33 +03:00
training_width ,
training_height ,
2023-01-07 20:34:52 +03:00
varsize ,
2022-10-02 15:03:39 +03:00
steps ,
2022-10-28 05:44:56 +03:00
clip_grad_mode ,
clip_grad_value ,
2022-11-20 06:35:26 +03:00
shuffle_tags ,
tag_drop_out ,
latent_sampling_method ,
2022-10-02 15:03:39 +03:00
create_image_every ,
save_embedding_every ,
template_file ,
2022-10-14 20:31:49 +03:00
preview_from_txt2img ,
* txt2img_preview_params ,
2022-10-02 15:03:39 +03:00
] ,
outputs = [
ti_output ,
ti_outcome ,
]
)
interrupt_training . click (
fn = lambda : shared . state . interrupt ( ) ,
inputs = [ ] ,
outputs = [ ] ,
)
2022-11-18 05:03:57 +03:00
interrupt_preprocessing . click (
fn = lambda : shared . state . interrupt ( ) ,
inputs = [ ] ,
outputs = [ ] ,
)
2022-10-13 19:22:41 +03:00
# Create and return the gradio component that edits the option stored under `key`.
#
# key: name of the option; looked up in opts.data_labels (metadata) and opts.data (current value).
# is_quicksettings: only matters when the option has a refresh callback. When True, the
#   component and its refresh button are created directly in the current container;
#   when False, both are wrapped in a FormRow.
# Raises Exception when the option declares no component and its default is not str/int/bool.
def create_setting_component ( key , is_quicksettings = False ) :
2022-09-03 12:08:45 +03:00
# Current value of the option, falling back to the declared default when it is not set.
def fun ( ) :
return opts . data [ key ] if key in opts . data else opts . data_labels [ key ] . default
info = opts . data_labels [ key ]
t = type ( info . default )
2022-09-11 23:00:42 +03:00
# component_args may be a callable producing the kwargs or the kwargs themselves (possibly None).
args = info . component_args ( ) if callable ( info . component_args ) else info . component_args
2022-09-03 12:08:45 +03:00
# An explicitly declared component class wins; otherwise it is inferred from the default's exact type.
if info . component is not None :
2022-09-11 23:00:42 +03:00
comp = info . component
2022-09-03 12:08:45 +03:00
elif t == str :
2022-09-11 23:00:42 +03:00
comp = gr . Textbox
2022-09-03 12:08:45 +03:00
elif t == int :
2022-09-11 23:00:42 +03:00
comp = gr . Number
2022-09-03 12:08:45 +03:00
elif t == bool :
2022-09-11 23:00:42 +03:00
comp = gr . Checkbox
2022-09-03 12:08:45 +03:00
else :
raise Exception ( f ' bad options item type: { str ( t ) } for key { key } ' )
2022-10-17 21:15:32 +03:00
elem_id = " setting_ " + key
2022-10-13 19:22:41 +03:00
# Options with a refresh callback get a refresh button next to the component.
# The button receives the original info.component_args (not the evaluated `args`).
if info . refresh is not None :
if is_quicksettings :
2022-11-06 14:39:41 +03:00
# `args or { }` guards against component_args being None.
res = comp ( label = info . label , value = fun ( ) , elem_id = elem_id , * * ( args or { } ) )
2022-10-17 21:15:32 +03:00
create_refresh_button ( res , info . refresh , info . component_args , " refresh_ " + key )
2022-10-13 19:22:41 +03:00
else :
2023-01-03 09:04:29 +03:00
with FormRow ( ) :
2022-11-06 14:39:41 +03:00
res = comp ( label = info . label , value = fun ( ) , elem_id = elem_id , * * ( args or { } ) )
2022-10-17 21:15:32 +03:00
create_refresh_button ( res , info . refresh , info . component_args , " refresh_ " + key )
2022-10-13 19:22:41 +03:00
else :
2022-11-06 14:39:41 +03:00
res = comp ( label = info . label , value = fun ( ) , elem_id = elem_id , * * ( args or { } ) )
2022-10-13 19:22:41 +03:00
return res
2022-09-03 12:08:45 +03:00
2022-09-10 11:10:00 +03:00
components = [ ]
2022-09-29 00:59:44 +03:00
component_dict = { }
2022-09-10 11:10:00 +03:00
2022-10-22 19:18:56 +03:00
script_callbacks . ui_settings_callback ( )
opts . reorder ( )
2022-09-03 12:08:45 +03:00
# Apply every value submitted from the settings page and persist the options.
#
# *args: one value per entry of `components`, positionally aligned with opts.data_labels.
# Returns a pair (opts.dumpjson(), status message listing the keys that changed).
def run_settings ( * args ) :
2022-11-06 10:12:53 +03:00
changed = [ ]
2022-09-23 17:27:30 +03:00
# First pass: validate all values before touching any option, so a bad value aborts
# the whole submission. Slots holding dummy_component are not validated.
# NOTE(review): validation relies on `assert`, which Python strips under -O.
for key , value , comp in zip ( opts . data_labels . keys ( ) , args , components ) :
2022-11-04 10:35:30 +03:00
assert comp == dummy_component or opts . same_type ( value , opts . data_labels [ key ] . default ) , f " Bad value for setting { key } : { value } ; expecting { type ( opts . data_labels [ key ] . default ) . __name__ } "
2022-09-03 12:08:45 +03:00
2022-09-10 11:10:00 +03:00
# Second pass: apply the values, skipping dummy_component slots.
for key , value , comp in zip ( opts . data_labels . keys ( ) , args , components ) :
2022-10-09 22:24:07 +03:00
if comp == dummy_component :
continue
2022-11-19 15:15:24 +03:00
# A truthy return from opts.set is recorded as "this key changed".
if opts . set ( key , value ) :
2022-11-06 10:12:53 +03:00
changed . append ( key )
2022-09-11 23:00:42 +03:00
2022-11-04 22:24:42 +03:00
# A RuntimeError from saving is not fatal: the in-memory changes are kept and
# the status message says they were not saved.
try :
opts . save ( shared . config_filename )
except RuntimeError :
2022-11-06 10:12:53 +03:00
return opts . dumpjson ( ) , f ' { len ( changed ) } settings changed without save: { " , " . join ( changed ) } . '
2022-12-09 09:47:45 +03:00
return opts . dumpjson ( ) , f ' { len ( changed ) } settings changed { " : " if len ( changed ) > 0 else " " } { " , " . join ( changed ) } . '
2022-09-03 12:08:45 +03:00
2022-10-09 22:24:07 +03:00
def run_settings_single(value, key):
    """Apply one option value (used by the quicksettings change handlers).

    Returns a pair ``(component update, opts.dumpjson())``:

    * value has the wrong type for the option -> ``gr.update(visible=True)``,
      nothing is stored;
    * ``opts.set`` reports no change -> an update resetting the component to
      the option's current value, nothing is saved;
    * otherwise the options are saved and ``get_value_for_setting(key)`` is
      returned for the component.
    """
    expected_default = opts.data_labels[key].default
    type_ok = opts.same_type(value, expected_default)
    if not type_ok:
        return gr.update(visible=True), opts.dumpjson()

    was_changed = opts.set(key, value)
    if not was_changed:
        current_value = getattr(opts, key)
        return gr.update(value=current_value), opts.dumpjson()

    opts.save(shared.config_filename)
    return get_value_for_setting(key), opts.dumpjson()
2022-10-09 22:24:07 +03:00
2022-09-10 11:10:00 +03:00
with gr . Blocks ( analytics_enabled = False ) as settings_interface :
2023-01-03 07:20:20 +03:00
with gr . Row ( ) :
2023-01-03 20:23:17 +03:00
with gr . Column ( scale = 6 ) :
settings_submit = gr . Button ( value = " Apply settings " , variant = ' primary ' , elem_id = " settings_submit " )
with gr . Column ( ) :
restart_gradio = gr . Button ( value = ' Reload UI ' , variant = ' primary ' , elem_id = " settings_restart_gradio " )
2022-09-10 11:10:00 +03:00
2023-01-03 07:20:20 +03:00
result = gr . HTML ( elem_id = " settings_result " )
2022-09-10 11:10:00 +03:00
2022-10-13 16:07:18 +03:00
quicksettings_names = [ x . strip ( ) for x in opts . quicksettings . split ( " , " ) ]
2023-01-03 09:13:35 +03:00
quicksettings_names = { x : i for i , x in enumerate ( quicksettings_names ) if x != ' quicksettings ' }
2022-10-13 16:07:18 +03:00
2022-10-09 22:24:07 +03:00
quicksettings_list = [ ]
2022-09-22 21:32:44 +03:00
previous_section = None
2023-01-03 07:20:20 +03:00
current_tab = None
2023-01-14 14:56:39 +03:00
current_row = None
2023-01-03 07:20:20 +03:00
with gr . Tabs ( elem_id = " settings " ) :
2022-09-22 21:32:44 +03:00
for i , ( k , item ) in enumerate ( opts . data_labels . items ( ) ) :
2022-10-31 17:36:45 +03:00
section_must_be_skipped = item . section [ 0 ] is None
2022-09-22 19:26:26 +03:00
2022-10-31 17:36:45 +03:00
if previous_section != item . section and not section_must_be_skipped :
2023-01-03 07:20:20 +03:00
elem_id , text = item . section
2022-09-22 19:26:26 +03:00
2023-01-03 07:20:20 +03:00
if current_tab is not None :
2023-01-14 14:56:39 +03:00
current_row . __exit__ ( )
2023-01-03 07:20:20 +03:00
current_tab . __exit__ ( )
2022-09-10 11:10:00 +03:00
2023-01-14 14:56:39 +03:00
gr . Group ( )
2023-01-03 07:20:20 +03:00
current_tab = gr . TabItem ( elem_id = " settings_ {} " . format ( elem_id ) , label = text )
current_tab . __enter__ ( )
2023-01-14 14:56:39 +03:00
current_row = gr . Column ( variant = ' compact ' )
current_row . __enter__ ( )
2022-09-22 21:32:44 +03:00
previous_section = item . section
2022-10-22 22:05:22 +03:00
if k in quicksettings_names and not shared . cmd_opts . freeze_settings :
2022-10-09 22:24:07 +03:00
quicksettings_list . append ( ( i , k , item ) )
components . append ( dummy_component )
2022-10-31 17:36:45 +03:00
elif section_must_be_skipped :
components . append ( dummy_component )
2022-10-09 22:24:07 +03:00
else :
component = create_setting_component ( k )
component_dict [ k ] = component
components . append ( component )
2022-09-03 12:08:45 +03:00
2023-01-03 07:20:20 +03:00
if current_tab is not None :
2023-01-14 14:56:39 +03:00
current_row . __exit__ ( )
2023-01-03 07:20:20 +03:00
current_tab . __exit__ ( )
2022-10-17 21:15:32 +03:00
2023-01-03 07:20:20 +03:00
with gr . TabItem ( " Actions " ) :
request_notifications = gr . Button ( value = ' Request browser notifications ' , elem_id = " request_notifications " )
download_localization = gr . Button ( value = ' Download localization template ' , elem_id = " download_localization " )
reload_script_bodies = gr . Button ( value = ' Reload custom script bodies (No ui updates, No restart) ' , variant = ' secondary ' , elem_id = " settings_reload_script_bodies " )
2022-10-13 16:07:18 +03:00
2023-01-21 08:36:07 +03:00
with gr . TabItem ( " Licenses " ) :
gr . HTML ( shared . html ( " licenses.html " ) , elem_id = " licenses " )
2023-01-03 20:23:17 +03:00
2023-01-03 10:01:06 +03:00
gr . Button ( value = " Show all pages " , elem_id = " settings_show_all_pages " )
2022-10-13 16:07:18 +03:00
2022-09-19 04:41:57 +03:00
request_notifications . click (
fn = lambda : None ,
inputs = [ ] ,
outputs = [ ] ,
2022-09-22 13:15:33 +03:00
_js = ' function() {} '
2022-09-19 04:41:57 +03:00
)
2022-10-17 21:15:32 +03:00
download_localization . click (
fn = lambda : None ,
inputs = [ ] ,
outputs = [ ] ,
_js = ' download_localization '
)
2022-10-02 03:19:55 +03:00
def reload_scripts():
    """Reload the bodies of custom scripts, then re-run the JavaScript injection."""
    modules.scripts.reload_script_body_only()
    # The HTML page has to be refreshed for the reloaded scripts to show up.
    reload_javascript()
2022-10-02 03:19:55 +03:00
reload_script_bodies . click (
fn = reload_scripts ,
inputs = [ ] ,
2022-11-02 09:47:53 +03:00
outputs = [ ]
2022-10-02 03:19:55 +03:00
)
2022-10-02 03:36:30 +03:00
def request_restart():
    """Interrupt the current job and flag on ``shared.state`` that a restart is needed."""
    shared.state.interrupt()
    shared.state.need_restart = True
2022-10-02 03:36:30 +03:00
restart_gradio . click (
fn = request_restart ,
2022-11-06 09:02:25 +03:00
_js = ' restart_reload ' ,
2022-10-02 03:36:30 +03:00
inputs = [ ] ,
outputs = [ ] ,
)
2022-10-10 04:26:52 +03:00
2022-09-03 12:08:45 +03:00
interfaces = [
2022-09-10 11:10:00 +03:00
( txt2img_interface , " txt2img " , " txt2img " ) ,
( img2img_interface , " img2img " , " img2img " ) ,
( extras_interface , " Extras " , " extras " ) ,
( pnginfo_interface , " PNG Info " , " pnginfo " ) ,
2022-09-26 02:22:12 +03:00
( modelmerger_interface , " Checkpoint Merger " , " modelmerger " ) ,
2022-10-12 11:05:57 +03:00
( train_interface , " Train " , " ti " ) ,
2022-09-03 12:08:45 +03:00
]
2022-10-22 13:34:49 +03:00
css = " "
for cssfile in modules . scripts . list_files_with_name ( " style.css " ) :
2022-10-22 14:28:56 +03:00
if not os . path . isfile ( cssfile ) :
continue
2022-10-22 13:34:49 +03:00
with open ( cssfile , " r " , encoding = " utf8 " ) as file :
css + = file . read ( ) + " \n "
2022-09-03 12:08:45 +03:00
2022-09-17 16:35:58 +03:00
if os . path . exists ( os . path . join ( script_path , " user.css " ) ) :
2022-09-17 16:28:19 +03:00
with open ( os . path . join ( script_path , " user.css " ) , " r " , encoding = " utf8 " ) as file :
2022-10-22 13:34:49 +03:00
css + = file . read ( ) + " \n "
2022-09-17 16:28:19 +03:00
2022-09-03 12:08:45 +03:00
if not cmd_opts . no_progressbar_hiding :
css + = css_hide_progressbar
2022-10-29 10:56:19 +03:00
interfaces + = script_callbacks . ui_tabs_callback ( )
interfaces + = [ ( settings_interface , " Settings " , " settings " ) ]
2022-10-31 17:36:45 +03:00
extensions_interface = ui_extensions . create_ui ( )
interfaces + = [ ( extensions_interface , " Extensions " , " extensions " ) ]
2022-09-10 11:10:00 +03:00
with gr . Blocks ( css = css , analytics_enabled = False , title = " Stable Diffusion " ) as demo :
2023-01-18 14:33:09 +03:00
with gr . Row ( elem_id = " quicksettings " , variant = " compact " ) :
2023-01-03 09:13:35 +03:00
for i , k , item in sorted ( quicksettings_list , key = lambda x : quicksettings_names . get ( x [ 1 ] , x [ 0 ] ) ) :
2022-10-13 19:22:41 +03:00
component = create_setting_component ( k , is_quicksettings = True )
2022-10-09 22:24:07 +03:00
component_dict [ k ] = component
2022-10-29 10:56:19 +03:00
parameters_copypaste . integrate_settings_paste_fields ( component_dict )
parameters_copypaste . run_bind ( )
2022-10-13 20:42:27 +03:00
with gr . Tabs ( elem_id = " tabs " ) as tabs :
2022-09-10 11:10:00 +03:00
for interface , label , ifid in interfaces :
2022-10-11 08:22:46 +03:00
with gr . TabItem ( label , id = ifid , elem_id = ' tab_ ' + ifid ) :
2022-09-10 11:10:00 +03:00
interface . render ( )
2022-10-10 04:26:52 +03:00
2022-09-26 23:57:31 +03:00
if os . path . exists ( os . path . join ( script_path , " notification.mp3 " ) ) :
audio_notification = gr . Audio ( interactive = False , value = os . path . join ( script_path , " notification.mp3 " ) , elem_id = " audio_notification " , visible = False )
2022-09-10 11:10:00 +03:00
2023-01-21 08:36:07 +03:00
footer = shared . html ( " footer.html " )
footer = footer . format ( versions = versions_html ( ) )
gr . HTML ( footer , elem_id = " footer " )
2023-01-03 20:23:17 +03:00
2022-09-19 17:16:04 +03:00
text_settings = gr . Textbox ( elem_id = " settings_json " , value = lambda : opts . dumpjson ( ) , visible = False )
2022-09-18 22:25:18 +03:00
settings_submit . click (
2022-11-04 10:35:30 +03:00
fn = wrap_gradio_call ( run_settings , extra_outputs = [ gr . update ( ) ] ) ,
2022-09-24 00:13:32 +03:00
inputs = components ,
2022-11-04 10:35:30 +03:00
outputs = [ text_settings , result ] ,
2022-09-18 22:25:18 +03:00
)
2022-10-09 22:24:07 +03:00
for i , k , item in quicksettings_list :
component = component_dict [ k ]
component . change (
fn = lambda value , k = k : run_settings_single ( value , key = k ) ,
inputs = [ component ] ,
outputs = [ component , text_settings ] ,
)
2022-11-06 14:39:41 +03:00
component_keys = [ k for k in opts . data_labels . keys ( ) if k in component_dict ]
def get_settings_values ( ) :
2023-01-19 18:07:37 +03:00
return [ get_value_for_setting ( key ) for key in component_keys ]
2022-11-06 14:39:41 +03:00
demo . load (
fn = get_settings_values ,
inputs = [ ] ,
outputs = [ component_dict [ k ] for k in component_keys ] ,
)
2022-09-29 02:50:34 +03:00
# UI wrapper around modules.extras.run_modelmerger that turns any failure into UI updates.
#
# *args: forwarded unchanged to run_modelmerger.
# Returns run_modelmerger's result on success. On any Exception, returns a 5-item list:
# four Dropdown updates carrying the refreshed checkpoint choices followed by an error
# message, matching the five outputs wired to the merge button's click handler.
def modelmerger ( * args ) :
try :
2022-10-02 15:03:39 +03:00
results = modules . extras . run_modelmerger ( * args )
2022-09-29 02:50:34 +03:00
except Exception as e :
# Report the full traceback on stderr; the UI only receives the short message below.
print ( " Error loading/saving model file: " , file = sys . stderr )
print ( traceback . format_exc ( ) , file = sys . stderr )
2022-10-02 15:03:39 +03:00
modules . sd_models . list_models ( ) # to remove the potentially missing models from the list
2023-01-19 09:25:37 +03:00
return [ * [ gr . Dropdown . update ( choices = modules . sd_models . checkpoint_tiles ( ) ) for _ in range ( 4 ) ] , f " Error merging checkpoints: { e } " ]
2022-09-29 02:50:34 +03:00
return results
2022-09-18 22:25:18 +03:00
2023-01-19 10:39:51 +03:00
modelmerger_merge . click ( fn = lambda : ' ' , inputs = [ ] , outputs = [ modelmerger_result ] )
2022-09-29 00:59:44 +03:00
modelmerger_merge . click (
2023-01-19 09:25:37 +03:00
fn = wrap_gradio_gpu_call ( modelmerger , extra_outputs = lambda : [ gr . update ( ) for _ in range ( 4 ) ] ) ,
_js = ' modelmerger ' ,
2022-09-29 00:59:44 +03:00
inputs = [
2023-01-19 09:25:37 +03:00
dummy_component ,
2022-09-29 00:59:44 +03:00
primary_model_name ,
secondary_model_name ,
2022-10-14 09:05:06 +03:00
tertiary_model_name ,
2022-09-29 00:59:44 +03:00
interp_method ,
interp_amount ,
save_as_half ,
2022-09-29 02:50:34 +03:00
custom_name ,
2022-11-27 15:51:29 +03:00
checkpoint_format ,
2023-01-11 09:10:07 +03:00
config_source ,
2023-01-19 10:39:51 +03:00
bake_in_vae ,
2023-01-22 10:17:12 +03:00
discard_weights ,
2022-09-29 00:59:44 +03:00
] ,
outputs = [
primary_model_name ,
secondary_model_name ,
2022-10-14 09:05:06 +03:00
tertiary_model_name ,
2022-09-29 00:59:44 +03:00
component_dict [ ' sd_model_checkpoint ' ] ,
2023-01-19 09:25:37 +03:00
modelmerger_result ,
2022-09-29 00:59:44 +03:00
]
)
2022-09-23 22:49:21 +03:00
2022-09-10 08:18:54 +03:00
ui_config_file = cmd_opts . ui_config_file
2022-09-04 13:52:01 +03:00
ui_settings = { }
settings_count = len ( ui_settings )
error_loading = False
try :
if os . path . exists ( ui_config_file ) :
with open ( ui_config_file , " r " , encoding = " utf8 " ) as file :
ui_settings = json . load ( file )
except Exception :
error_loading = True
print ( " Error loading settings: " , file = sys . stderr )
print ( traceback . format_exc ( ) , file = sys . stderr )
# Callback passed to visit(): synchronises one UI component `x` with the `ui_settings`
# dict loaded from the ui-config file.
#
# path: key prefix identifying the component inside ui_settings.
# x: the visited component; only exact instances of Slider, Radio, Checkbox, Textbox,
#    Number and Dropdown are handled (checked with type(x), so subclasses are ignored).
def loadsave ( path , x ) :
2022-10-18 12:51:57 +03:00
# Sync a single attribute `field` of `obj` with ui_settings.
# condition: optional predicate; a saved value it rejects is ignored.
# init_field: optional callable; when given, it is called with the saved value.
def apply_field ( obj , field , condition = None , init_field = None ) :
2022-09-04 13:52:01 +03:00
key = path + " / " + field
2022-09-25 08:56:50 +03:00
2022-10-29 10:56:19 +03:00
# Components created by custom scripts get their key prefixed with the script source.
if getattr ( obj , ' custom_script_source ' , None ) is not None :
2022-09-25 08:56:50 +03:00
key = ' customscript/ ' + obj . custom_script_source + ' / ' + key
2022-10-10 04:26:52 +03:00
2022-09-25 19:43:42 +03:00
# Components can opt out of config persistence entirely.
if getattr ( obj , ' do_not_save_to_config ' , False ) :
return
2022-10-10 04:26:52 +03:00
2022-09-04 13:52:01 +03:00
# No saved value (missing key, or a stored None): record the component's current
# value in ui_settings. Otherwise push the saved value onto the component, unless
# `condition` rejects it.
saved_value = ui_settings . get ( key , None )
if saved_value is None :
ui_settings [ key ] = getattr ( obj , field )
2022-10-15 23:09:11 +03:00
elif condition and not condition ( saved_value ) :
2023-01-18 23:04:24 +03:00
pass
# this warning is generally not useful;
# print(f'Warning: Bad ui setting value: {key}: {saved_value}; Default value "{getattr(obj, field)}" will be used instead.')
2022-10-15 23:09:11 +03:00
else :
2022-09-04 13:52:01 +03:00
setattr ( obj , field , saved_value )
2022-10-18 12:51:57 +03:00
if init_field is not None :
init_field ( saved_value )
2022-09-04 13:52:01 +03:00
2023-01-06 16:03:43 +03:00
# `visible` is only synced for components that are currently visible.
if type ( x ) in [ gr . Slider , gr . Radio , gr . Checkbox , gr . Textbox , gr . Number , gr . Dropdown ] and x . visible :
2022-09-25 19:43:42 +03:00
apply_field ( x , ' visible ' )
2022-09-04 13:52:01 +03:00
# Sliders also persist their range and step in addition to the value.
if type ( x ) == gr . Slider :
apply_field ( x , ' value ' )
apply_field ( x , ' minimum ' )
apply_field ( x , ' maximum ' )
apply_field ( x , ' step ' )
# A saved radio value is only applied when it is one of the current choices.
if type ( x ) == gr . Radio :
2022-09-05 19:11:29 +03:00
apply_field ( x , ' value ' , lambda val : val in x . choices )
2022-09-04 13:52:01 +03:00
2022-09-25 08:31:02 +03:00
if type ( x ) == gr . Checkbox :
2022-09-25 08:40:37 +03:00
apply_field ( x , ' value ' )
2022-09-25 08:31:02 +03:00
if type ( x ) == gr . Textbox :
2022-09-25 08:40:37 +03:00
apply_field ( x , ' value ' )
2022-10-10 04:26:52 +03:00
2022-09-25 08:39:22 +03:00
if type ( x ) == gr . Number :
2022-09-25 08:40:37 +03:00
apply_field ( x , ' value ' )
2022-10-10 04:26:52 +03:00
2023-01-06 16:03:43 +03:00
if type ( x ) == gr . Dropdown :
2023-01-14 14:56:39 +03:00
# Multiselect dropdowns hold a collection: every saved entry must be a valid
# choice. Single-select dropdowns hold one value that must be a valid choice.
def check_dropdown ( val ) :
2023-01-21 20:07:14 +03:00
if getattr ( x , ' multiselect ' , False ) :
2023-01-14 14:56:39 +03:00
return all ( [ value in x . choices for value in val ] )
else :
return val in x . choices
# The dropdown's optional `init_field` attribute is forwarded as the init_field callback.
apply_field ( x , ' value ' , check_dropdown , getattr ( x , ' init_field ' , None ) )
2022-10-15 22:47:03 +03:00
2022-09-04 13:52:01 +03:00
visit ( txt2img_interface , loadsave , " txt2img " )
visit ( img2img_interface , loadsave , " img2img " )
2022-09-11 11:31:16 +03:00
visit ( extras_interface , loadsave , " extras " )
2022-10-17 19:56:23 +03:00
visit ( modelmerger_interface , loadsave , " modelmerger " )
2023-01-04 20:10:40 +03:00
visit ( train_interface , loadsave , " train " )
2022-09-04 13:52:01 +03:00
if not error_loading and ( not os . path . exists ( ui_config_file ) or settings_count != len ( ui_settings ) ) :
with open ( ui_config_file , " w " , encoding = " utf8 " ) as file :
json . dump ( ui_settings , file , indent = 4 )
2023-01-20 08:48:15 +03:00
# Required as a workaround for change() event not triggering when loading values from ui-config.json
interp_description . value = update_interp_description ( interp_method . value )
2022-09-03 12:08:45 +03:00
return demo
2022-11-08 08:35:01 +03:00
# Build the <script> tags for the UI's JavaScript and patch gradio's TemplateResponse
# so that every rendered page gets them inserted before its closing head tag.
def reload_javascript ( ) :
2023-01-21 22:57:19 +03:00
# script.js is referenced by absolute path; os.path.abspath resolves it against the
# current working directory.
head = f ' <script type= " text/javascript " src= " file= { os . path . abspath ( " script.js " ) } " ></script> \n '
2022-09-03 12:08:45 +03:00
2023-01-21 22:57:19 +03:00
# Inline script: the localization JS, plus a set_theme call when a theme was given.
inline = f " { localization . localization_js ( shared . opts . localization ) } ; "
2022-10-19 09:43:49 +03:00
# NOTE(review): cmd_opts.theme is interpolated into the JS string without escaping.
if cmd_opts . theme is not None :
2023-01-21 22:57:19 +03:00
inline + = f " set_theme( ' { cmd_opts . theme } ' ); "
2022-09-03 12:08:45 +03:00
2023-01-21 22:57:19 +03:00
# One script tag per file returned by list_scripts for the javascript directory.
for script in modules . scripts . list_scripts ( " javascript " , " .js " ) :
head + = f ' <script type= " text/javascript " src= " file= { script . path } " ></script> \n '
2022-09-03 12:08:45 +03:00
2023-01-23 11:54:42 +03:00
# The inline script is appended last, after all the file-based script tags.
head + = f ' <script type= " text/javascript " > { inline } </script> \n '
2022-10-02 21:26:06 +03:00
# Replacement for gradio's TemplateResponse: builds the original response (the class
# saved on `shared`) and splices `head` into its body.
def template_response ( * args , * * kwargs ) :
2022-11-08 08:35:01 +03:00
res = shared . GradioTemplateResponseOriginal ( * args , * * kwargs )
2023-01-21 22:57:19 +03:00
res . body = res . body . replace ( b ' </head> ' , f ' { head } </head> ' . encode ( " utf8 " ) )
2022-10-02 21:26:06 +03:00
# presumably recomputes headers such as the content length for the modified body - TODO confirm
res . init_headers ( )
return res
gradio . routes . templates . TemplateResponse = template_response
2022-10-12 19:19:34 +03:00
2022-10-14 20:04:47 +03:00
2022-11-08 08:35:01 +03:00
# Remember gradio's TemplateResponse on `shared` the first time this runs. The hasattr
# guard keeps a later run from overwriting it; reload_javascript() wraps this saved
# original when it installs its replacement.
if not hasattr ( shared , ' GradioTemplateResponseOriginal ' ) :
shared . GradioTemplateResponseOriginal = gradio . routes . templates . TemplateResponse
2023-01-05 11:57:01 +03:00
# Return the HTML snippet listing the versions of python, torch, xformers and gradio,
# a link to the current commit, and a placeholder element for the checkpoint hash.
# The result is substituted into the footer template via its `versions` field.
def versions_html ( ) :
# Imported locally rather than at module level.
import torch
import launch
# major.minor.micro only; the full sys.version string goes into the span's title attribute.
python_version = " . " . join ( [ str ( x ) for x in sys . version_info [ 0 : 3 ] ] )
commit = launch . commit_hash ( )
# The link shows only the first 8 characters of the hash; the href uses the full hash.
short_commit = commit [ 0 : 8 ]
# xformers is imported only when shared reports it as available.
if shared . xformers_available :
import xformers
xformers_version = xformers . __version__
else :
xformers_version = " N/A "
return f """
python : < span title = " {sys.version} " > { python_version } < / span >
•
torch : { torch . __version__ }
•
xformers : { xformers_version }
•
gradio : { gr . __version__ }
•
commit : < a href = " https://github.com/AUTOMATIC1111/stable-diffusion-webui/commit/ {commit} " > { short_commit } < / a >
2023-01-14 15:55:40 +03:00
•
checkpoint : < a id = " sd_checkpoint_hash " > N / A < / a >
2023-01-05 11:57:01 +03:00
"""