2022-09-03 12:08:45 +03:00
import base64
import html
import io
import json
import mimetypes
import os
import sys
import time
import traceback
from PIL import Image
import gradio as gr
import gradio . utils
2022-09-03 17:21:15 +03:00
import gradio . routes
2022-09-03 12:08:45 +03:00
from modules . paths import script_path
from modules . shared import opts , cmd_opts
import modules . shared as shared
from modules . sd_samplers import samplers , samplers_for_img2img
import modules . gfpgan_model as gfpgan
import modules . realesrgan_model as realesrgan
2022-09-03 17:21:15 +03:00
import modules . scripts
2022-09-03 12:08:45 +03:00
# this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the browser will not show any UI
# mimetypes.init() (re)builds the module's type map, so the .js override is registered after it
mimetypes.init()
mimetypes.add_type('application/javascript', '.js')

if not cmd_opts.share:
    # fix gradio phoning home: when the UI is not shared, stub out gradio's
    # version check and make its local-IP lookup return loopback
    gradio.utils.version_check = lambda: None
    gradio.utils.get_local_ip_address = lambda: '127.0.0.1'
def gr_show(visible=True):
    """Build a gradio update payload that toggles a component's visibility.

    Args:
        visible: True to show the component, False to hide it.

    Returns:
        A dict with the keys ``visible`` and ``__type__`` (always ``"update"``).
    """
    return dict(visible=visible, __type__="update")
# Path of the bundled img2img sample picture (relative to the current working
# directory), or None when that file is not present.
sample_img2img = "assets/stable-samples/img2img/sketch-mountains-input.jpg"
if not os.path.exists(sample_img2img):
    sample_img2img = None
css_hide_progressbar = """
. wrap . m - 12 svg { display : none ! important ; }
. wrap . m - 12 : : before { content : " Loading... " }
. progress - bar { display : none ! important ; }
. meta - text { display : none ! important ; }
"""
def plaintext_to_html(text):
    """Render plain text as HTML paragraphs.

    Each ``\\n``-separated line is HTML-escaped and wrapped in ``<p>...</p>``
    followed by a newline. A trailing newline in ``text`` therefore produces a
    final empty paragraph.

    Args:
        text: The plain text to convert.

    Returns:
        The concatenated HTML string.
    """
    paragraphs = (f"<p>{html.escape(line)}</p>\n" for line in text.split('\n'))
    return "".join(paragraphs)
def image_from_url_text(filedata):
    """Decode a base64-encoded image into a PIL image.

    Args:
        filedata: A base64 string, optionally prefixed with a data-URL header
            such as ``data:image/png;base64,``, or a list whose first element is
            such a string (as sent from a gallery). ``None`` is accepted.

    Returns:
        The decoded ``PIL.Image.Image``, or ``None`` when ``filedata`` is
        ``None`` or an empty list.
    """
    if filedata is None:
        return None

    if isinstance(filedata, list):
        if not filedata:
            return None
        filedata = filedata[0]

    # Strip any "data:<mime>;base64," header, not just the PNG one.
    # base64.decodebytes discards non-alphabet characters instead of raising,
    # so a leftover header would silently corrupt the decoded bytes.
    if filedata.startswith("data:") and "," in filedata:
        filedata = filedata.split(",", 1)[1]

    filedata = base64.decodebytes(filedata.encode('utf-8'))
    image = Image.open(io.BytesIO(filedata))
    return image
def send_gradio_gallery_to_image(x):
    """Decode the first entry of a gallery value as an image.

    Args:
        x: List of gallery entries (base64 / data-URL strings), or None.

    Returns:
        The decoded image, or None when the gallery is empty or None.
    """
    # Truthiness also covers None, which previously crashed in len().
    if not x:
        return None
    return image_from_url_text(x[0])
def save_files(js_data, images):
    """Save gallery images to ``opts.outdir_save`` and log them to ``log.csv``.

    Each image is written as ``<millis>.png`` (or ``<millis>-<n>.png`` when
    there are several). One CSV row with the generation parameters and the
    first filename is appended to ``log.csv``; a header row is written first
    when the file is empty.

    Args:
        js_data: JSON string containing the keys ``prompt``, ``seed``,
            ``width``, ``height``, ``sampler``, ``cfg_scale`` and ``steps``.
        images: List of base64 image strings, optionally data URLs.

    Returns:
        ``('', '', html)`` where ``html`` reports the first saved filename, or
        that there was nothing to save.
    """
    import csv

    # Bail out before creating directories, parsing js_data or touching
    # log.csv; previously an empty list failed on filenames[0] after the CSV
    # header might already have been written.
    if not images:
        return '', '', plaintext_to_html("Nothing to save.")

    os.makedirs(opts.outdir_save, exist_ok=True)

    filenames = []

    data = json.loads(js_data)

    with open(os.path.join(opts.outdir_save, "log.csv"), "a", encoding="utf8", newline='') as file:
        at_start = file.tell() == 0
        writer = csv.writer(file)
        if at_start:
            writer.writerow(["prompt", "seed", "width", "height", "sampler", "cfgs", "steps", "filename"])

        filename_base = str(int(time.time() * 1000))
        for i, filedata in enumerate(images):
            filename = filename_base + ("" if len(images) == 1 else "-" + str(i + 1)) + ".png"
            filepath = os.path.join(opts.outdir_save, filename)

            # Strip any "data:<mime>;base64," header; decodebytes would
            # otherwise decode the header characters into garbage bytes.
            if filedata.startswith("data:") and "," in filedata:
                filedata = filedata.split(",", 1)[1]

            with open(filepath, "wb") as imgfile:
                imgfile.write(base64.decodebytes(filedata.encode('utf-8')))

            filenames.append(filename)

        writer.writerow([data["prompt"], data["seed"], data["width"], data["height"], data["sampler"], data["cfg_scale"], data["steps"], filenames[0]])

    return '', '', plaintext_to_html(f"Saved: {filenames[0]}")
def wrap_gradio_call(func):
    """Wrap a UI callback with error reporting and timing.

    The wrapper calls ``func``, turns its result into a list and appends a
    "Time taken" paragraph to the last element, which must be an HTML string.
    If ``func`` raises, the error and traceback are printed to stderr and the
    result becomes ``[None, '', <error html>]``. In either case
    ``shared.state.interrupted`` is reset to False afterwards.

    Args:
        func: Callable returning an iterable whose last item is HTML.

    Returns:
        The wrapper function. It returns a tuple and carries ``func``'s name
        and docstring.
    """
    import functools

    @functools.wraps(func)
    def f(*args, **kwargs):
        t = time.perf_counter()

        try:
            res = list(func(*args, **kwargs))
        except Exception as e:
            print("Error completing request", file=sys.stderr)
            print("Arguments:", args, kwargs, file=sys.stderr)
            print(traceback.format_exc(), file=sys.stderr)

            res = [None, '', f"<div class='error'>{plaintext_to_html(type(e).__name__ + ': ' + str(e))}</div>"]

        elapsed = time.perf_counter() - t

        # last item is always HTML
        res[-1] = res[-1] + f"<p class='performance'>Time taken: {elapsed:.2f}s</p>"

        shared.state.interrupted = False

        return tuple(res)

    return f
2022-09-04 13:52:01 +03:00
def visit(x, func, path=""):
    """Walk a component tree depth-first and call ``func`` on labelled leaves.

    Any object with a ``children`` attribute is treated as a container; its
    children are visited in order with the same ``path`` prefix. Every other
    object whose ``label`` is not None is passed to
    ``func(path + "/" + str(label), obj)``.

    Args:
        x: Root object of the tree.
        func: Callback taking ``(key, component)``.
        path: Prefix used to build each key.
    """
    if not hasattr(x, 'children'):
        if x.label is not None:
            func("/".join((path, str(x.label))), x)
        return

    for child in x.children:
        visit(child, func, path)
2022-09-03 12:08:45 +03:00
2022-09-04 13:52:01 +03:00
def create_ui(txt2img, img2img, run_extras, run_pnginfo):
    """Build the tabbed Gradio UI.

    Creates the txt2img, img2img, Extras, PNG Info and Settings tabs and wires
    their buttons to the given callbacks. It then restores slider and radio
    settings of the txt2img/img2img tabs from ``ui-config.json`` in the script
    directory. The file is written back when it does not exist or when new
    entries were added, unless loading it failed.

    Args:
        txt2img: Callback for the txt2img tab. It receives prompt, negative
            prompt, steps, sampler index, GFPGAN flag, batch count, batch size,
            CFG scale, seed, height and width, followed by the txt2img script
            UI inputs. It returns ``(gallery, generation_info, html_info)``.
        img2img: Callback for the img2img tab. It receives the inputs listed in
            ``img2img_args`` below, followed by the img2img script UI inputs,
            and returns the same three outputs as ``txt2img``.
        run_extras: Callback for the Extras tab. It is called with
            ``(image, gfpgan_strength, realesrgan_resize, realesrgan_model)``
            and returns ``(result_image, html_info_x, html_info)``.
        run_pnginfo: Callback for the PNG Info tab. It is called with a PIL
            image and returns three HTML values; it is wrapped with
            ``wrap_gradio_call``.

    Returns:
        The ``gr.TabbedInterface`` containing all tabs.
    """
    with gr.Blocks(analytics_enabled=False) as txt2img_interface:
        with gr.Row():
            prompt = gr.Textbox(label="Prompt", elem_id="txt2img_prompt", show_label=False, placeholder="Prompt", lines=1)
            negative_prompt = gr.Textbox(label="Negative prompt", elem_id="txt2img_negative_prompt", show_label=False, placeholder="Negative prompt", lines=1, visible=False)
            submit = gr.Button('Generate', elem_id="txt2img_generate", variant='primary')

        with gr.Row().style(equal_height=False):
            with gr.Column(variant='panel'):
                steps = gr.Slider(minimum=1, maximum=150, step=1, label="Sampling Steps", value=20)
                sampler_index = gr.Radio(label='Sampling method', elem_id="txt2img_sampling", choices=[x.name for x in samplers], value=samplers[0].name, type="index")

                with gr.Row():
                    use_gfpgan = gr.Checkbox(label='GFPGAN', value=False, visible=gfpgan.have_gfpgan)

                with gr.Row():
                    batch_count = gr.Slider(minimum=1, maximum=cmd_opts.max_batch_count, step=1, label='Batch count', value=1)
                    batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1)

                cfg_scale = gr.Slider(minimum=1.0, maximum=15.0, step=0.5, label='CFG Scale', value=7.0)

                with gr.Group():
                    height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512)
                    width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512)

                seed = gr.Number(label='Seed', value=-1)

                with gr.Group():
                    custom_inputs = modules.scripts.scripts_txt2img.setup_ui(is_img2img=False)

            with gr.Column(variant='panel'):
                with gr.Group():
                    txt2img_gallery = gr.Gallery(label='Output', elem_id='txt2img_gallery')

                with gr.Group():
                    with gr.Row():
                        save = gr.Button('Save')
                        send_to_img2img = gr.Button('Send to img2img')
                        send_to_inpaint = gr.Button('Send to inpaint')
                        send_to_extras = gr.Button('Send to extras')
                        interrupt = gr.Button('Interrupt')

                with gr.Group():
                    html_info = gr.HTML()
                    generation_info = gr.Textbox(visible=False)

        txt2img_args = dict(
            fn=txt2img,
            inputs=[
                prompt,
                negative_prompt,
                steps,
                sampler_index,
                use_gfpgan,
                batch_count,
                batch_size,
                cfg_scale,
                seed,
                height,
                width,
            ] + custom_inputs,
            outputs=[
                txt2img_gallery,
                generation_info,
                html_info
            ]
        )

        prompt.submit(**txt2img_args)
        submit.click(**txt2img_args)

        interrupt.click(
            fn=lambda: shared.state.interrupt(),
            inputs=[],
            outputs=[],
        )

        save.click(
            fn=wrap_gradio_call(save_files),
            inputs=[
                generation_info,
                txt2img_gallery,
            ],
            outputs=[
                html_info,
                html_info,
                html_info,
            ]
        )

    with gr.Blocks(analytics_enabled=False) as img2img_interface:
        with gr.Row():
            prompt = gr.Textbox(label="Prompt", elem_id="img2img_prompt", show_label=False, placeholder="Prompt", lines=1)
            submit = gr.Button('Generate', elem_id="img2img_generate", variant='primary')

        with gr.Row().style(equal_height=False):
            with gr.Column(variant='panel'):
                with gr.Group():
                    switch_mode = gr.Radio(label='Mode', elem_id="img2img_mode", choices=['Redraw whole image', 'Inpaint a part of image', 'Loopback', 'SD upscale'], value='Redraw whole image', type="index", show_label=False)
                    init_img = gr.Image(label="Image for img2img", source="upload", interactive=True, type="pil")
                    init_img_with_mask = gr.Image(label="Image for inpainting with mask", elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool="sketch", visible=False)
                    resize_mode = gr.Radio(label="Resize mode", show_label=False, choices=["Just resize", "Crop and resize", "Resize and fill"], type="index", value="Just resize")

                steps = gr.Slider(minimum=1, maximum=150, step=1, label="Sampling Steps", value=20)
                sampler_index = gr.Radio(label='Sampling method', choices=[x.name for x in samplers_for_img2img], value=samplers_for_img2img[0].name, type="index")
                mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4, visible=False)
                inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='fill', type="index", visible=False)

                with gr.Row():
                    inpaint_full_res = gr.Checkbox(label='Inpaint at full resolution', value=False, visible=False)
                    inpainting_mask_invert = gr.Radio(label='Masking mode', choices=['Inpaint masked', 'Inpaint not masked'], value='Inpaint masked', type="index", visible=False)

                with gr.Row():
                    use_gfpgan = gr.Checkbox(label='GFPGAN', value=False, visible=gfpgan.have_gfpgan)

                with gr.Row():
                    sd_upscale_upscaler_name = gr.Radio(label='Upscaler', choices=list(shared.sd_upscalers.keys()), value=list(shared.sd_upscalers.keys())[0], visible=False)
                    sd_upscale_overlap = gr.Slider(minimum=0, maximum=256, step=16, label='Tile overlap', value=64, visible=False)

                with gr.Row():
                    batch_count = gr.Slider(minimum=1, maximum=cmd_opts.max_batch_count, step=1, label='Batch count', value=1)
                    batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1)

                with gr.Group():
                    cfg_scale = gr.Slider(minimum=1.0, maximum=15.0, step=0.5, label='CFG Scale', value=7.0)
                    denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising Strength', value=0.75)

                with gr.Group():
                    height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512)
                    width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512)

                seed = gr.Number(label='Seed', value=-1)

                with gr.Group():
                    custom_inputs = modules.scripts.scripts_img2img.setup_ui(is_img2img=True)

            with gr.Column(variant='panel'):
                with gr.Group():
                    img2img_gallery = gr.Gallery(label='Output', elem_id='img2img_gallery')

                with gr.Group():
                    with gr.Row():
                        interrupt = gr.Button('Interrupt')
                        save = gr.Button('Save')
                        img2img_send_to_extras = gr.Button('Send to extras')

                with gr.Group():
                    html_info = gr.HTML()
                    generation_info = gr.Textbox(visible=False)

        def apply_mode(mode):
            # mode is the index into switch_mode's choices:
            # 0 redraw, 1 inpaint, 2 loopback, 3 SD upscale
            is_inpaint = mode == 1
            is_loopback = mode == 2
            is_upscale = mode == 3

            return {
                init_img: gr_show(not is_inpaint),
                init_img_with_mask: gr_show(is_inpaint),
                mask_blur: gr_show(is_inpaint),
                inpainting_fill: gr_show(is_inpaint),
                batch_count: gr_show(not is_upscale),
                batch_size: gr_show(not is_loopback),
                sd_upscale_upscaler_name: gr_show(is_upscale),
                sd_upscale_overlap: gr_show(is_upscale),
                inpaint_full_res: gr_show(is_inpaint),
                inpainting_mask_invert: gr_show(is_inpaint),
            }

        switch_mode.change(
            apply_mode,
            inputs=[switch_mode],
            outputs=[
                init_img,
                init_img_with_mask,
                mask_blur,
                inpainting_fill,
                batch_count,
                batch_size,
                sd_upscale_upscaler_name,
                sd_upscale_overlap,
                inpaint_full_res,
                inpainting_mask_invert,
            ]
        )

        img2img_args = dict(
            fn=img2img,
            inputs=[
                prompt,
                init_img,
                init_img_with_mask,
                steps,
                sampler_index,
                mask_blur,
                inpainting_fill,
                use_gfpgan,
                switch_mode,
                batch_count,
                batch_size,
                cfg_scale,
                denoising_strength,
                seed,
                height,
                width,
                resize_mode,
                sd_upscale_upscaler_name,
                sd_upscale_overlap,
                inpaint_full_res,
                inpainting_mask_invert,
            ] + custom_inputs,
            outputs=[
                img2img_gallery,
                generation_info,
                html_info
            ]
        )

        prompt.submit(**img2img_args)
        submit.click(**img2img_args)

        interrupt.click(
            fn=lambda: shared.state.interrupt(),
            inputs=[],
            outputs=[],
        )

        save.click(
            fn=wrap_gradio_call(save_files),
            inputs=[
                generation_info,
                img2img_gallery,
            ],
            outputs=[
                html_info,
                html_info,
                html_info,
            ]
        )

        # Cross-tab wiring: txt2img buttons feed the img2img inputs.
        send_to_img2img.click(
            fn=image_from_url_text,
            _js="extract_image_from_gallery",
            inputs=[txt2img_gallery],
            outputs=[init_img],
        )

        send_to_inpaint.click(
            fn=image_from_url_text,
            _js="extract_image_from_gallery",
            inputs=[txt2img_gallery],
            outputs=[init_img_with_mask],
        )

    with gr.Blocks(analytics_enabled=False) as extras_interface:
        with gr.Row().style(equal_height=False):
            with gr.Column(variant='panel'):
                with gr.Group():
                    image = gr.Image(label="Source", source="upload", interactive=True, type="pil")

                gfpgan_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="GFPGAN strength", value=1, interactive=gfpgan.have_gfpgan)
                realesrgan_resize = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label="Real-ESRGAN upscaling", value=2, interactive=realesrgan.have_realesrgan)
                realesrgan_model = gr.Radio(label='Real-ESRGAN model', choices=[x.name for x in realesrgan.realesrgan_models], value=realesrgan.realesrgan_models[0].name, type="index", interactive=realesrgan.have_realesrgan)

                submit = gr.Button('Generate', elem_id="extras_generate", variant='primary')

            with gr.Column(variant='panel'):
                result_image = gr.Image(label="Result")
                html_info_x = gr.HTML()
                html_info = gr.HTML()

        extras_args = dict(
            fn=run_extras,
            inputs=[
                image,
                gfpgan_strength,
                realesrgan_resize,
                realesrgan_model,
            ],
            outputs=[
                result_image,
                html_info_x,
                html_info,
            ]
        )

        submit.click(**extras_args)

        # Cross-tab wiring: both galleries can send their image to Extras.
        send_to_extras.click(
            fn=image_from_url_text,
            _js="extract_image_from_gallery",
            inputs=[txt2img_gallery],
            outputs=[image],
        )

        img2img_send_to_extras.click(
            fn=image_from_url_text,
            _js="extract_image_from_gallery",
            inputs=[img2img_gallery],
            outputs=[image],
        )

    pnginfo_interface = gr.Interface(
        wrap_gradio_call(run_pnginfo),
        inputs=[
            gr.Image(label="Source", source="upload", interactive=True, type="pil"),
        ],
        outputs=[
            gr.HTML(),
            gr.HTML(),
            gr.HTML(),
        ],
        allow_flagging="never",
        analytics_enabled=False,
    )

    def create_setting_component(key):
        """Create the input component for the option ``key`` in ``opts.data_labels``."""
        def fun():
            # Evaluated lazily by gradio so the current value is shown.
            return opts.data[key] if key in opts.data else opts.data_labels[key].default

        info = opts.data_labels[key]
        t = type(info.default)

        if info.component is not None:
            item = info.component(label=info.label, value=fun, **(info.component_args or {}))
        elif t == str:
            item = gr.Textbox(label=info.label, value=fun, lines=1)
        elif t in (int, float):
            # float options previously fell through to the exception below
            item = gr.Number(label=info.label, value=fun)
        elif t == bool:
            item = gr.Checkbox(label=info.label, value=fun)
        else:
            raise Exception(f'bad options item type: {str(t)} for key {key}')

        return item

    def run_settings(*args):
        """Store the submitted values in ``opts`` (in data_labels order) and save them."""
        for key, value in zip(opts.data_labels.keys(), args):
            opts.data[key] = value

        opts.save(shared.config_filename)

        return 'Settings saved.', '', ''

    settings_interface = gr.Interface(
        run_settings,
        inputs=[create_setting_component(key) for key in opts.data_labels.keys()],
        outputs=[
            gr.Textbox(label='Result'),
            gr.HTML(),
            gr.HTML(),
        ],
        title=None,
        description=None,
        allow_flagging="never",
        analytics_enabled=False,
    )

    interfaces = [
        (txt2img_interface, "txt2img"),
        (img2img_interface, "img2img"),
        (extras_interface, "Extras"),
        (pnginfo_interface, "PNG Info"),
        (settings_interface, "Settings"),
    ]

    with open(os.path.join(script_path, "style.css"), "r", encoding="utf8") as file:
        css = file.read()

    if not cmd_opts.no_progressbar_hiding:
        css += css_hide_progressbar

    demo = gr.TabbedInterface(
        interface_list=[x[0] for x in interfaces],
        tab_names=[x[1] for x in interfaces],
        analytics_enabled=False,
        css=css,
    )

    ui_config_file = os.path.join(script_path, 'ui-config.json')
    ui_settings = {}
    error_loading = False

    try:
        if os.path.exists(ui_config_file):
            with open(ui_config_file, "r", encoding="utf8") as file:
                ui_settings = json.load(file)
    except Exception:
        error_loading = True
        print("Error loading settings:", file=sys.stderr)
        print(traceback.format_exc(), file=sys.stderr)

    # Count the entries AFTER loading. Counting before (always 0) made the file
    # be rewritten on every launch instead of only when loadsave adds new keys.
    settings_count = len(ui_settings)

    def loadsave(path, x):
        """Apply saved slider/radio settings to ``x``, or record its defaults."""
        def apply_field(obj, field):
            key = path + "/" + field
            saved_value = ui_settings.get(key)
            if saved_value is None:
                ui_settings[key] = getattr(obj, field)
            else:
                setattr(obj, field, saved_value)

        if isinstance(x, gr.Slider):
            for field in ('value', 'minimum', 'maximum', 'step'):
                apply_field(x, field)

        if isinstance(x, gr.Radio):
            apply_field(x, 'value')

    visit(txt2img_interface, loadsave, "txt2img")
    visit(img2img_interface, loadsave, "img2img")

    if not error_loading and (not os.path.exists(ui_config_file) or settings_count != len(ui_settings)):
        with open(ui_config_file, "w", encoding="utf8") as file:
            json.dump(ui_settings, file, indent=4)

    return demo
2022-09-03 17:21:15 +03:00
# Read script.js from the script directory once, at import time.
# template_response inserts this text into pages rendered through gradio's
# TemplateResponse.
with open(os.path.join(script_path, "script.js"), "r", encoding="utf8") as jsfile:
    javascript = jsfile.read()
2022-09-03 12:08:45 +03:00
2022-09-03 17:21:15 +03:00
def template_response(*args, **kwargs):
    """Render a gradio template with the contents of script.js injected.

    Delegates to the original TemplateResponse held in
    ``gradio_routes_templates_response``. It then inserts a ``<script>``
    element containing the module-level ``javascript`` text before
    ``</head>``, and calls ``init_headers()`` on the response (presumably so
    headers such as Content-Length match the new body).
    """
    response = gradio_routes_templates_response(*args, **kwargs)
    script_tag = f'<script>{javascript}</script></head>'.encode("utf8")
    response.body = response.body.replace(b'</head>', script_tag)
    response.init_headers()
    return response
2022-09-03 12:08:45 +03:00
2022-09-03 17:21:15 +03:00
# Save gradio's original TemplateResponse before replacing it with
# template_response, which delegates to this saved reference. The capture
# must happen first, or template_response would end up calling itself.
gradio_routes_templates_response = gradio.routes.templates.TemplateResponse
gradio.routes.templates.TemplateResponse = template_response