diff --git a/script.js b/script.js index 1d7de7da..5d38b334 100644 --- a/script.js +++ b/script.js @@ -33,7 +33,7 @@ titles = { } function gradioApp(){ - return document.getElementsByTagName('gradio-app')[0]; + return document.getElementsByTagName('gradio-app')[0].shadowRoot; } function addTitles(root){ @@ -47,8 +47,33 @@ function addTitles(root){ document.addEventListener("DOMContentLoaded", function() { var mutationObserver = new MutationObserver(function(m){ - addTitles(gradioApp().shadowRoot); + addTitles(gradioApp()); }); - mutationObserver.observe( gradioApp().shadowRoot, { childList:true, subtree:true }) + mutationObserver.observe( gradioApp(), { childList:true, subtree:true }) }); + +function selected_gallery_index(){ + var gr = gradioApp() + var buttons = gradioApp().querySelectorAll(".gallery-item") + var button = gr.querySelector(".gallery-item.\\!ring-2") + + var result = -1 + buttons.forEach(function(v, i){ if(v==button) { result = i } }) + + return result +} + +function extract_image_from_gallery(gallery){ + if(gallery.length == 1){ + return gallery[0] + } + + index = selected_gallery_index() + + if (index < 0 || index >= gallery.length){ + return [] + } + + return gallery[index]; +} diff --git a/webui.py b/webui.py index c51a7829..f7a52107 100644 --- a/webui.py +++ b/webui.py @@ -1288,7 +1288,14 @@ def txt2img(prompt: str, negative_prompt: str, steps: int, sampler_index: int, u return processed.images, processed.js(), plaintext_to_html(processed.info) + def image_from_url_text(filedata): + if type(filedata) == list: + if len(filedata) == 0: + return None + + filedata = filedata[0] + if filedata.startswith("data:image/png;base64,"): filedata = filedata[len("data:image/png;base64,"):] @@ -1368,7 +1375,7 @@ with gr.Blocks(analytics_enabled=False) as txt2img_interface: with gr.Column(variant='panel'): with gr.Group(): - txt2img_gallery = gr.Gallery(label='Output') + txt2img_gallery = gr.Gallery(label='Output', elem_id='txt2img_gallery') with gr.Group(): with gr.Row(): @@ -1760,7 +1767,7 @@ with gr.Blocks(analytics_enabled=False) as img2img_interface: with gr.Column(variant='panel'): with gr.Group(): - img2img_gallery = gr.Gallery(label='Output') + img2img_gallery = gr.Gallery(label='Output', elem_id='img2img_gallery') with gr.Group(): with gr.Row(): @@ -1863,13 +1870,15 @@ with gr.Blocks(analytics_enabled=False) as img2img_interface: ) send_to_img2img.click( - fn=send_gradio_gallery_to_image, + fn=lambda x: image_from_url_text(x), + _js="extract_image_from_gallery", inputs=[txt2img_gallery], outputs=[init_img], ) send_to_inpaint.click( - fn=send_gradio_gallery_to_image, + fn=lambda x: image_from_url_text(x), + _js="extract_image_from_gallery", inputs=[txt2img_gallery], outputs=[init_img_with_mask], ) @@ -1952,14 +1961,17 @@ with gr.Blocks(analytics_enabled=False) as extras_interface: submit.click(**extras_args) + send_to_extras.click( - fn=send_gradio_gallery_to_image, + fn=lambda x: image_from_url_text(x), + _js="extract_image_from_gallery", inputs=[txt2img_gallery], outputs=[image], ) img2img_send_to_extras.click( - fn=send_gradio_gallery_to_image, + fn=lambda x: image_from_url_text(x), + _js="extract_image_from_gallery", inputs=[img2img_gallery], outputs=[image], )