fix broken ↙ button, fix field paste ignoring most of useful fields for for #3768

This commit is contained in:
AUTOMATIC 2022-10-29 10:56:19 +03:00
parent beb6fc2979
commit 35c45df28b
2 changed files with 42 additions and 39 deletions

View File

@ -6,7 +6,7 @@ import gradio as gr
from modules.shared import script_path from modules.shared import script_path
from modules import shared from modules import shared
import tempfile import tempfile
from PIL import Image, PngImagePlugin from PIL import Image
re_param_code = r'\s*([\w ]+):\s*("(?:\\|\"|[^\"])+"|[^,]*)(?:,|$)' re_param_code = r'\s*([\w ]+):\s*("(?:\\|\"|[^\"])+"|[^,]*)(?:,|$)'
re_param = re.compile(re_param_code) re_param = re.compile(re_param_code)
@ -61,6 +61,24 @@ def add_paste_fields(tabname, init_img, fields):
modules.ui.img2img_paste_fields = fields modules.ui.img2img_paste_fields = fields
def integrate_settings_paste_fields(component_dict):
from modules import ui
settings_map = {
'sd_hypernetwork': 'Hypernet',
'CLIP_stop_at_last_layers': 'Clip skip',
'sd_model_checkpoint': 'Model hash',
}
settings_paste_fields = [
(component_dict[k], lambda d, k=k, v=v: ui.apply_setting(k, d.get(v, None)))
for k, v in settings_map.items()
]
for tabname, info in paste_fields.items():
if info["fields"] is not None:
info["fields"] += settings_paste_fields
def create_buttons(tabs_list): def create_buttons(tabs_list):
buttons = {} buttons = {}
for tab in tabs_list: for tab in tabs_list:
@ -93,18 +111,16 @@ def run_bind():
) )
if send_generate_info and paste_fields[tab]["fields"] is not None: if send_generate_info and paste_fields[tab]["fields"] is not None:
paste_field_names = ['Prompt', 'Negative prompt', 'Steps', 'Face restoration', 'Size-1', 'Size-2']
if shared.opts.send_seed:
paste_field_names += ["Seed"]
if send_generate_info in paste_fields: if send_generate_info in paste_fields:
paste_field_names = ['Prompt', 'Negative prompt', 'Steps', 'Face restoration', 'Size-1', 'Size-2'] + (["Seed"] if shared.opts.send_seed else [])
button.click( button.click(
fn=lambda *x: x, fn=lambda *x: x,
inputs=[field for field, name in paste_fields[send_generate_info]["fields"] if name in paste_field_names], inputs=[field for field, name in paste_fields[send_generate_info]["fields"] if name in paste_field_names],
outputs=[field for field, name in paste_fields[tab]["fields"] if name in paste_field_names], outputs=[field for field, name in paste_fields[tab]["fields"] if name in paste_field_names],
) )
else: else:
connect_paste(button, [(field, name) for field, name in paste_fields[tab]["fields"] if name in paste_field_names], send_generate_info) connect_paste(button, paste_fields[tab]["fields"], send_generate_info)
button.click( button.click(
fn=None, fn=None,

View File

@ -589,6 +589,7 @@ def create_refresh_button(refresh_component, refresh_method, refreshed_args, ele
) )
return refresh_button return refresh_button
def create_output_panel(tabname, outdir): def create_output_panel(tabname, outdir):
def open_folder(f): def open_folder(f):
if not os.path.exists(f): if not os.path.exists(f):
@ -716,6 +717,7 @@ def create_ui(wrap_gradio_gpu_call):
custom_inputs = modules.scripts.scripts_txt2img.setup_ui(is_img2img=False) custom_inputs = modules.scripts.scripts_txt2img.setup_ui(is_img2img=False)
txt2img_gallery, generation_info, html_info = create_output_panel("txt2img", opts.outdir_txt2img_samples) txt2img_gallery, generation_info, html_info = create_output_panel("txt2img", opts.outdir_txt2img_samples)
parameters_copypaste.bind_buttons({"txt2img": txt2img_paste}, None, txt2img_prompt)
connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False) connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False)
connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True) connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True)
@ -784,7 +786,7 @@ def create_ui(wrap_gradio_gpu_call):
] ]
) )
parameters_copypaste.add_paste_fields("txt2img", None, [ txt2img_paste_fields = [
(txt2img_prompt, "Prompt"), (txt2img_prompt, "Prompt"),
(txt2img_negative_prompt, "Negative prompt"), (txt2img_negative_prompt, "Negative prompt"),
(steps, "Steps"), (steps, "Steps"),
@ -805,7 +807,8 @@ def create_ui(wrap_gradio_gpu_call):
(firstphase_width, "First pass size-1"), (firstphase_width, "First pass size-1"),
(firstphase_height, "First pass size-2"), (firstphase_height, "First pass size-2"),
*modules.scripts.scripts_txt2img.infotext_fields *modules.scripts.scripts_txt2img.infotext_fields
]) ]
parameters_copypaste.add_paste_fields("txt2img", None, txt2img_paste_fields)
txt2img_preview_params = [ txt2img_preview_params = [
txt2img_prompt, txt2img_prompt,
@ -893,6 +896,7 @@ def create_ui(wrap_gradio_gpu_call):
custom_inputs = modules.scripts.scripts_img2img.setup_ui(is_img2img=True) custom_inputs = modules.scripts.scripts_img2img.setup_ui(is_img2img=True)
img2img_gallery, generation_info, html_info = create_output_panel("img2img", opts.outdir_img2img_samples) img2img_gallery, generation_info, html_info = create_output_panel("img2img", opts.outdir_img2img_samples)
parameters_copypaste.bind_buttons({"img2img": img2img_paste}, None, img2img_prompt)
connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False) connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False)
connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True) connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True)
@ -1038,7 +1042,6 @@ def create_ui(wrap_gradio_gpu_call):
parameters_copypaste.add_paste_fields("img2img", init_img, img2img_paste_fields) parameters_copypaste.add_paste_fields("img2img", init_img, img2img_paste_fields)
parameters_copypaste.add_paste_fields("inpaint", init_img_with_mask, img2img_paste_fields) parameters_copypaste.add_paste_fields("inpaint", init_img_with_mask, img2img_paste_fields)
with gr.Blocks(analytics_enabled=False) as extras_interface: with gr.Blocks(analytics_enabled=False) as extras_interface:
with gr.Row().style(equal_height=False): with gr.Row().style(equal_height=False):
with gr.Column(variant='panel'): with gr.Column(variant='panel'):
@ -1050,12 +1053,8 @@ def create_ui(wrap_gradio_gpu_call):
image_batch = gr.File(label="Batch Process", file_count="multiple", interactive=True, type="file") image_batch = gr.File(label="Batch Process", file_count="multiple", interactive=True, type="file")
with gr.TabItem('Batch from Directory'): with gr.TabItem('Batch from Directory'):
extras_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, extras_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, placeholder="A directory on the same machine where the server is running.")
placeholder="A directory on the same machine where the server is running." extras_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, placeholder="Leave blank to save images to the default path.")
)
extras_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs,
placeholder="Leave blank to save images to the default path."
)
show_extras_results = gr.Checkbox(label='Show result images', value=True) show_extras_results = gr.Checkbox(label='Show result images', value=True)
with gr.Tabs(elem_id="extras_resize_mode"): with gr.Tabs(elem_id="extras_resize_mode"):
@ -1087,7 +1086,6 @@ def create_ui(wrap_gradio_gpu_call):
submit = gr.Button('Generate', elem_id="extras_generate", variant='primary') submit = gr.Button('Generate', elem_id="extras_generate", variant='primary')
result_images, html_info_x, html_info = create_output_panel("extras", opts.outdir_extras_samples) result_images, html_info_x, html_info = create_output_panel("extras", opts.outdir_extras_samples)
submit.click( submit.click(
@ -1121,7 +1119,6 @@ def create_ui(wrap_gradio_gpu_call):
) )
parameters_copypaste.add_paste_fields("extras", extras_image, None) parameters_copypaste.add_paste_fields("extras", extras_image, None)
extras_image.change( extras_image.change(
fn=modules.extras.clear_cache, fn=modules.extras.clear_cache,
inputs=[], outputs=[] inputs=[], outputs=[]
@ -1587,9 +1584,6 @@ def create_ui(wrap_gradio_gpu_call):
if column is not None: if column is not None:
column.__exit__() column.__exit__()
interfaces = [ interfaces = [
(txt2img_interface, "txt2img", "txt2img"), (txt2img_interface, "txt2img", "txt2img"),
(img2img_interface, "img2img", "img2img"), (img2img_interface, "img2img", "img2img"),
@ -1599,10 +1593,6 @@ def create_ui(wrap_gradio_gpu_call):
(train_interface, "Train", "ti"), (train_interface, "Train", "ti"),
] ]
interfaces += script_callbacks.ui_tabs_callback()
interfaces += [(settings_interface, "Settings", "settings")]
css = "" css = ""
for cssfile in modules.scripts.list_files_with_name("style.css"): for cssfile in modules.scripts.list_files_with_name("style.css"):
@ -1619,6 +1609,9 @@ def create_ui(wrap_gradio_gpu_call):
if not cmd_opts.no_progressbar_hiding: if not cmd_opts.no_progressbar_hiding:
css += css_hide_progressbar css += css_hide_progressbar
interfaces += script_callbacks.ui_tabs_callback()
interfaces += [(settings_interface, "Settings", "settings")]
with gr.Blocks(css=css, analytics_enabled=False, title="Stable Diffusion") as demo: with gr.Blocks(css=css, analytics_enabled=False, title="Stable Diffusion") as demo:
with gr.Row(elem_id="quicksettings"): with gr.Row(elem_id="quicksettings"):
for i, k, item in quicksettings_list: for i, k, item in quicksettings_list:
@ -1627,6 +1620,9 @@ def create_ui(wrap_gradio_gpu_call):
settings_interface.gradio_ref = demo settings_interface.gradio_ref = demo
parameters_copypaste.integrate_settings_paste_fields(component_dict)
parameters_copypaste.run_bind()
with gr.Tabs(elem_id="tabs") as tabs: with gr.Tabs(elem_id="tabs") as tabs:
for interface, label, ifid in interfaces: for interface, label, ifid in interfaces:
with gr.TabItem(label, id=ifid, elem_id='tab_' + ifid): with gr.TabItem(label, id=ifid, elem_id='tab_' + ifid):
@ -1681,15 +1677,6 @@ def create_ui(wrap_gradio_gpu_call):
] ]
) )
settings_map = {
'sd_hypernetwork': 'Hypernet',
'CLIP_stop_at_last_layers': 'Clip skip',
'sd_model_checkpoint': 'Model hash',
}
parameters_copypaste.run_bind()
ui_config_file = cmd_opts.ui_config_file ui_config_file = cmd_opts.ui_config_file
ui_settings = {} ui_settings = {}
settings_count = len(ui_settings) settings_count = len(ui_settings)