mirror of
https://github.com/sd-webui/stable-diffusion-webui.git
synced 2024-12-14 14:52:31 +03:00
img2img-fix (#717)
This commit is contained in:
parent
f28255466b
commit
70d4b1ca2a
@ -15,14 +15,9 @@ def css(opt):
|
||||
# TODO: @altryne restore this before merge
|
||||
if not opt.no_progressbar_hiding:
|
||||
styling += readTextFile("css", "no_progress_bar.css")
|
||||
if opt.custom_css:
|
||||
try:
|
||||
styling += readTextFile("css", "custom.css")
|
||||
print("Custom CSS loaded")
|
||||
except:
|
||||
pass
|
||||
return styling
|
||||
|
||||
|
||||
def js(opt):
|
||||
data = readTextFile("js", "index.js")
|
||||
data = "(z) => {" + data + "; return z ?? [] }"
|
||||
|
@ -57,21 +57,20 @@ def draw_gradio_ui(opt, img2img=lambda x: x, txt2img=lambda x: x,imgproc=lambda
|
||||
output_txt2img_params = gr.Highlightedtext(label="Generation parameters", interactive=False, elem_id='highlight')
|
||||
with gr.Group():
|
||||
with gr.Row(elem_id='txt2img_output_row'):
|
||||
output_txt2img_copy_params = gr.Button("Copy all").click(
|
||||
output_txt2img_copy_params = gr.Button("Copy full parameters").click(
|
||||
inputs=[output_txt2img_params], outputs=[],
|
||||
_js=js_copy_txt2img_output,
|
||||
fn=None, show_progress=False)
|
||||
output_txt2img_seed = gr.Number(label='Seed', interactive=False, visible=False)
|
||||
output_txt2img_copy_seed = gr.Button("Copy seed").click(
|
||||
output_txt2img_copy_seed = gr.Button("Copy only seed").click(
|
||||
inputs=[output_txt2img_seed], outputs=[],
|
||||
_js='(x) => navigator.clipboard.writeText(x)', fn=None, show_progress=False)
|
||||
output_txt2img_stats = gr.HTML(label='Stats')
|
||||
with gr.Column():
|
||||
with gr.Row():
|
||||
|
||||
txt2img_steps = gr.Slider(minimum=1, maximum=250, step=1, label="Sampling Steps",
|
||||
value=txt2img_defaults['ddim_steps'])
|
||||
txt2img_sampling = gr.Dropdown(label='Sampling method (k_lms is default k-diffusion sampler)',
|
||||
txt2img_steps = gr.Slider(minimum=1, maximum=250, step=1, label="Sampling Steps",
|
||||
value=txt2img_defaults['ddim_steps'])
|
||||
txt2img_sampling = gr.Dropdown(label='Sampling method (k_lms is default k-diffusion sampler)',
|
||||
choices=["DDIM", "PLMS", 'k_dpm_2_a', 'k_dpm_2', 'k_euler_a',
|
||||
'k_euler', 'k_heun', 'k_lms'],
|
||||
value=txt2img_defaults['sampler_name'])
|
||||
@ -158,28 +157,22 @@ def draw_gradio_ui(opt, img2img=lambda x: x, txt2img=lambda x: x,imgproc=lambda
|
||||
img2img_btn_editor = gr.Button("Generate", variant="primary", elem_id="img2img_edit_btn")
|
||||
with gr.Row().style(equal_height=False):
|
||||
with gr.Column():
|
||||
with gr.Tabs():
|
||||
with gr.TabItem("Img2Img Input"):
|
||||
#gr.Markdown('#### Img2Img Input')
|
||||
img2img_image_editor = gr.Image(value=sample_img2img, source="upload", interactive=True,
|
||||
type="pil", tool="select", elem_id="img2img_editor",
|
||||
image_mode="RGBA")
|
||||
img2img_image_mask = gr.Image(value=sample_img2img, source="upload", interactive=True,
|
||||
type="pil", tool="sketch", visible=False,
|
||||
elem_id="img2img_mask")
|
||||
|
||||
with gr.TabItem("Img2Img Mask Input"):
|
||||
img2img_mask_input = gr.Image(label="Mask",source="upload", interactive=False,
|
||||
type="pil", visible=True)
|
||||
gr.Markdown('#### Img2Img Input')
|
||||
img2img_image_editor = gr.Image(value=sample_img2img, source="upload", interactive=True,
|
||||
type="pil", tool="select", elem_id="img2img_editor", image_mode="RGBA"
|
||||
)
|
||||
img2img_image_mask = gr.Image(value=sample_img2img, source="upload", interactive=True,
|
||||
type="pil", tool="sketch", visible=False, image_mode="RGBA",
|
||||
elem_id="img2img_mask")
|
||||
|
||||
with gr.Tabs():
|
||||
with gr.TabItem("Editor Options"):
|
||||
with gr.Row():
|
||||
img2img_image_editor_mode = gr.Radio(choices=["Mask", "Crop", "Uncrop"], label="Image Editor Mode",
|
||||
value="Crop", elem_id='edit_mode_select')
|
||||
img2img_mask = gr.Radio(choices=["Keep masked area", "Regenerate only masked area", "Resize and regenerate only masked area"],
|
||||
img2img_mask = gr.Radio(choices=["Keep masked area", "Regenerate only masked area"],
|
||||
label="Mask Mode", type="index",
|
||||
value=img2img_mask_modes[img2img_defaults['mask_mode']], visible=False)
|
||||
value=img2img_mask_modes[img2img_defaults['mask_mode']], visible=False)
|
||||
|
||||
img2img_mask_blur_strength = gr.Slider(minimum=1, maximum=10, step=1,
|
||||
label="How much blurry should the mask be? (to avoid hard edges)",
|
||||
@ -263,16 +256,22 @@ def draw_gradio_ui(opt, img2img=lambda x: x, txt2img=lambda x: x,imgproc=lambda
|
||||
|
||||
img2img_image_editor_mode.change(
|
||||
uifn.change_image_editor_mode,
|
||||
[img2img_image_editor_mode, img2img_image_editor, img2img_resize, img2img_width, img2img_height],
|
||||
[img2img_image_editor_mode,
|
||||
img2img_image_editor,
|
||||
img2img_image_mask,
|
||||
img2img_resize,
|
||||
img2img_width,
|
||||
img2img_height
|
||||
],
|
||||
[img2img_image_editor, img2img_image_mask, img2img_btn_editor, img2img_btn_mask,
|
||||
img2img_painterro_btn, img2img_mask, img2img_mask_blur_strength, img2img_mask_input]
|
||||
img2img_painterro_btn, img2img_mask, img2img_mask_blur_strength]
|
||||
)
|
||||
|
||||
img2img_image_editor.edit(
|
||||
uifn.update_image_mask,
|
||||
[img2img_image_editor, img2img_resize, img2img_width, img2img_height],
|
||||
img2img_image_mask
|
||||
)
|
||||
# img2img_image_editor_mode.change(
|
||||
# uifn.update_image_mask,
|
||||
# [img2img_image_editor, img2img_resize, img2img_width, img2img_height],
|
||||
# img2img_image_mask
|
||||
# )
|
||||
|
||||
output_txt2img_copy_to_input_btn.click(
|
||||
uifn.copy_img_to_input,
|
||||
@ -306,11 +305,11 @@ def draw_gradio_ui(opt, img2img=lambda x: x, txt2img=lambda x: x,imgproc=lambda
|
||||
)
|
||||
|
||||
img2img_func = img2img
|
||||
img2img_inputs = [img2img_prompt, img2img_image_editor_mode, img2img_image_editor, img2img_image_mask, img2img_mask,
|
||||
img2img_inputs = [img2img_prompt, img2img_image_editor_mode, img2img_mask,
|
||||
img2img_mask_blur_strength, img2img_steps, img2img_sampling, img2img_toggles,
|
||||
img2img_realesrgan_model_name, img2img_batch_count, img2img_cfg,
|
||||
img2img_denoising, img2img_seed, img2img_height, img2img_width, img2img_resize,
|
||||
img2img_embeddings, img2img_mask_input]
|
||||
img2img_image_editor, img2img_image_mask, img2img_embeddings]
|
||||
img2img_outputs = [output_img2img_gallery, output_img2img_seed, output_img2img_params, output_img2img_stats]
|
||||
|
||||
# If a JobManager was passed in then wrap the Generate functions
|
||||
@ -321,33 +320,23 @@ def draw_gradio_ui(opt, img2img=lambda x: x, txt2img=lambda x: x,imgproc=lambda
|
||||
outputs=img2img_outputs,
|
||||
)
|
||||
|
||||
def generate(*args):
|
||||
args_list = list(args)
|
||||
init_info_mask = args_list[3]
|
||||
# Get the mask input and remove it from the list
|
||||
mask_input = args_list[18]
|
||||
del args_list[18]
|
||||
|
||||
# If an external mask is set, use it
|
||||
if mask_input:
|
||||
init_info_mask['mask'] = mask_input
|
||||
|
||||
args_list[3] = init_info_mask
|
||||
|
||||
# Return the result of img2img
|
||||
return img2img_func(*args_list)
|
||||
|
||||
img2img_btn_mask.click(
|
||||
generate,
|
||||
img2img_func,
|
||||
img2img_inputs,
|
||||
img2img_outputs
|
||||
)
|
||||
|
||||
img2img_btn_editor.click(
|
||||
img2img_func,
|
||||
def img2img_submit_params():
|
||||
#print([img2img_prompt, img2img_image_editor_mode, img2img_mask,
|
||||
# img2img_mask_blur_strength, img2img_steps, img2img_sampling, img2img_toggles,
|
||||
# img2img_realesrgan_model_name, img2img_batch_count, img2img_cfg,
|
||||
# img2img_denoising, img2img_seed, img2img_height, img2img_width, img2img_resize,
|
||||
# img2img_image_editor, img2img_image_mask, img2img_embeddings])
|
||||
return (img2img_func,
|
||||
img2img_inputs,
|
||||
img2img_outputs)
|
||||
|
||||
img2img_btn_editor.click(*img2img_submit_params())
|
||||
|
||||
# GENERATE ON ENTER
|
||||
img2img_prompt.submit(None, None, None,
|
||||
_js=call_JS("clickFirstVisibleButton",
|
||||
@ -374,7 +363,7 @@ def draw_gradio_ui(opt, img2img=lambda x: x, txt2img=lambda x: x,imgproc=lambda
|
||||
# value=gfpgan_defaults['strength'])
|
||||
#select folder with images to process
|
||||
with gr.TabItem('Batch Process'):
|
||||
imgproc_folder = gr.File(label="Batch Process", file_count="multiple", interactive=True, type="file")
|
||||
imgproc_folder = gr.File(label="Batch Process", file_count="multiple",source="upload", interactive=True, type="file")
|
||||
imgproc_pngnfo = gr.Textbox(label="PNG Metadata", placeholder="PngNfo", visible=False, max_lines=5)
|
||||
with gr.Row():
|
||||
imgproc_btn = gr.Button("Process", variant="primary")
|
||||
@ -580,7 +569,7 @@ def draw_gradio_ui(opt, img2img=lambda x: x, txt2img=lambda x: x,imgproc=lambda
|
||||
<div id="90" style="max-width: 100%; font-size: 14px; text-align: center;" class="output-markdown gr-prose border-solid border border-gray-200 rounded gr-panel">
|
||||
<p>For help and advanced usage guides, visit the <a href="https://github.com/hlky/stable-diffusion-webui/wiki" target="_blank">Project Wiki</a></p>
|
||||
<p>Stable Diffusion WebUI is an open-source project. You can find the latest stable builds on the <a href="https://github.com/hlky/stable-diffusion" target="_blank">main repository</a>.
|
||||
If you would like to contribute to development or test bleeding edge builds, you can visit the <a href="https://github.com/hlky/stable-diffusion-webui" target="_blank">development repository</a>.</p>
|
||||
If you would like to contribute to development or test bleeding edge builds, you can visit the <a href="https://github.com/hlky/stable-diffusion-webui" target="_blank">developement repository</a>.</p>
|
||||
</div>
|
||||
""")
|
||||
# Hack: Detect the load event on the frontend
|
||||
|
@ -1,7 +1,7 @@
|
||||
''' Provides simple job management for gradio, allowing viewing and stopping in-progress multi-batch generations '''
|
||||
from __future__ import annotations
|
||||
import gradio as gr
|
||||
from gradio.components import Component, Gallery, Slider
|
||||
from gradio.components import Component, Gallery
|
||||
from threading import Event, Timer
|
||||
from typing import Callable, List, Dict, Tuple, Optional, Any
|
||||
from dataclasses import dataclass, field
|
||||
@ -30,17 +30,7 @@ class JobInfo:
|
||||
session_key: str
|
||||
job_token: Optional[int] = None
|
||||
images: List[Image] = field(default_factory=list)
|
||||
active_image: Image = None
|
||||
rec_steps_enabled: bool = False
|
||||
rec_steps_imgs: List[Image] = field(default_factory=list)
|
||||
rec_steps_intrvl: int = None
|
||||
rec_steps_to_gallery: bool = False
|
||||
rec_steps_to_file: bool = False
|
||||
should_stop: Event = field(default_factory=Event)
|
||||
refresh_active_image_requested: Event = field(default_factory=Event)
|
||||
refresh_active_image_done: Event = field(default_factory=Event)
|
||||
stop_cur_iter: Event = field(default_factory=Event)
|
||||
active_iteration_cnt: int = field(default_factory=int)
|
||||
job_status: str = field(default_factory=str)
|
||||
finished: bool = False
|
||||
removed_output_idxs: List[int] = field(default_factory=list)
|
||||
@ -86,7 +76,7 @@ class JobManagerUi:
|
||||
'''
|
||||
return self._job_manager._wrap_func(
|
||||
func=func, inputs=inputs, outputs=outputs,
|
||||
job_ui=self
|
||||
refresh_btn=self._refresh_btn, stop_btn=self._stop_btn, status_text=self._status_text
|
||||
)
|
||||
|
||||
_refresh_btn: gr.Button
|
||||
@ -94,13 +84,6 @@ class JobManagerUi:
|
||||
_status_text: gr.Textbox
|
||||
_stop_all_session_btn: gr.Button
|
||||
_free_done_sessions_btn: gr.Button
|
||||
_active_image: gr.Image
|
||||
_active_image_stop_btn: gr.Button
|
||||
_active_image_refresh_btn: gr.Button
|
||||
_rec_steps_intrvl_sldr: gr.Slider
|
||||
_rec_steps_checkbox: gr.Checkbox
|
||||
_save_rec_steps_to_gallery_chkbx: gr.Checkbox
|
||||
_save_rec_steps_to_file_chkbx: gr.Checkbox
|
||||
_job_manager: JobManager
|
||||
|
||||
|
||||
@ -119,23 +102,11 @@ class JobManager:
|
||||
'''
|
||||
assert gr.context.Context.block is not None, "draw_gradio_ui must be called within a 'gr.Blocks' 'with' context"
|
||||
with gr.Tabs():
|
||||
with gr.TabItem("Job Controls"):
|
||||
with gr.TabItem("Current Session"):
|
||||
with gr.Row():
|
||||
stop_btn = gr.Button("Stop All Batches", elem_id="stop", variant="secondary")
|
||||
refresh_btn = gr.Button("Refresh Finished Batches", elem_id="refresh", variant="secondary")
|
||||
stop_btn = gr.Button("Stop", elem_id="stop", variant="secondary")
|
||||
refresh_btn = gr.Button("Refresh", elem_id="refresh", variant="secondary")
|
||||
status_text = gr.Textbox(placeholder="Job Status", interactive=False, show_label=False)
|
||||
with gr.Row():
|
||||
active_image_stop_btn = gr.Button("Skip Active Batch", variant="secondary")
|
||||
active_image_refresh_btn = gr.Button("View Batch Progress", variant="secondary")
|
||||
active_image = gr.Image(type="pil", interactive=False, visible=False, elem_id="active_iteration_image")
|
||||
with gr.TabItem("Batch Progress Settings"):
|
||||
with gr.Row():
|
||||
record_steps_checkbox = gr.Checkbox(value=False, label="Enable Batch Progress Grid")
|
||||
record_steps_interval_slider = gr.Slider(
|
||||
value=3, label="Record Interval (steps)", minimum=1, maximum=25, step=1)
|
||||
with gr.Row() as record_steps_box:
|
||||
steps_to_gallery_checkbox = gr.Checkbox(value=False, label="Save Progress Grid to Gallery")
|
||||
steps_to_file_checkbox = gr.Checkbox(value=False, label="Save Progress Grid to File")
|
||||
with gr.TabItem("Maintenance"):
|
||||
with gr.Row():
|
||||
gr.Markdown(
|
||||
@ -147,15 +118,9 @@ class JobManager:
|
||||
free_done_sessions_btn = gr.Button(
|
||||
"Clear Finished Jobs", elem_id="clear_finished", variant="secondary"
|
||||
)
|
||||
|
||||
return JobManagerUi(_refresh_btn=refresh_btn, _stop_btn=stop_btn, _status_text=status_text,
|
||||
_stop_all_session_btn=stop_all_sessions_btn, _free_done_sessions_btn=free_done_sessions_btn,
|
||||
_active_image=active_image, _active_image_stop_btn=active_image_stop_btn,
|
||||
_active_image_refresh_btn=active_image_refresh_btn,
|
||||
_rec_steps_checkbox=record_steps_checkbox,
|
||||
_save_rec_steps_to_gallery_chkbx=steps_to_gallery_checkbox,
|
||||
_save_rec_steps_to_file_chkbx=steps_to_file_checkbox,
|
||||
_rec_steps_intrvl_sldr=record_steps_interval_slider, _job_manager=self)
|
||||
_job_manager=self)
|
||||
|
||||
def clear_all_finished_jobs(self):
|
||||
''' Removes all currently finished jobs, across all sessions.
|
||||
@ -169,7 +134,6 @@ class JobManager:
|
||||
for session in self._sessions.values():
|
||||
for job in session.jobs.values():
|
||||
job.should_stop.set()
|
||||
job.stop_cur_iter.set()
|
||||
|
||||
def _get_job_token(self, block: bool = False) -> Optional[int]:
|
||||
''' Attempts to acquire a job token, optionally blocking until available '''
|
||||
@ -211,26 +175,6 @@ class JobManager:
|
||||
job_info.should_stop.set()
|
||||
return "Stopping after current batch finishes"
|
||||
|
||||
def _refresh_cur_iter_func(self, func_key: FuncKey, session_key: str) -> List[Component]:
|
||||
''' Updates information from the active iteration '''
|
||||
session_info, job_info = self._get_call_info(func_key, session_key)
|
||||
if job_info is None:
|
||||
return [None, f"Session {session_key} was not running function {func_key}"]
|
||||
|
||||
job_info.refresh_active_image_requested.set()
|
||||
if job_info.refresh_active_image_done.wait(timeout=20.0):
|
||||
job_info.refresh_active_image_done.clear()
|
||||
return [gr.Image.update(value=job_info.active_image, visible=True), f"Sample iteration {job_info.active_iteration_cnt}"]
|
||||
return [gr.Image.update(visible=False), "Timed out getting image"]
|
||||
|
||||
def _stop_cur_iter_func(self, func_key: FuncKey, session_key: str) -> List[Component]:
|
||||
''' Marks that the active iteration should be stopped'''
|
||||
session_info, job_info = self._get_call_info(func_key, session_key)
|
||||
if job_info is None:
|
||||
return [None, f"Session {session_key} was not running function {func_key}"]
|
||||
job_info.stop_cur_iter.set()
|
||||
return [gr.Image.update(visible=False), "Stopping current iteration"]
|
||||
|
||||
def _get_call_info(self, func_key: FuncKey, session_key: str) -> Tuple[SessionInfo, JobInfo]:
|
||||
''' Helper to get the SessionInfo and JobInfo. '''
|
||||
session_info = self._sessions.get(session_key, None)
|
||||
@ -263,8 +207,7 @@ class JobManager:
|
||||
|
||||
def _pre_call_func(
|
||||
self, func_key: FuncKey, output_dummy_obj: Component, refresh_btn: gr.Button, stop_btn: gr.Button,
|
||||
status_text: gr.Textbox, active_image: gr.Image, active_refresh_btn: gr.Button, active_stop_btn: gr.Button,
|
||||
session_key: str) -> List[Component]:
|
||||
status_text: gr.Textbox, session_key: str) -> List[Component]:
|
||||
''' Called when a job is about to start '''
|
||||
session_info, job_info = self._get_call_info(func_key, session_key)
|
||||
|
||||
@ -276,9 +219,7 @@ class JobManager:
|
||||
return {output_dummy_obj: triggerChangeEvent(),
|
||||
refresh_btn: gr.Button.update(variant="primary", value=refresh_btn.value),
|
||||
stop_btn: gr.Button.update(variant="primary", value=stop_btn.value),
|
||||
status_text: gr.Textbox.update(value="Generation has started. Click 'Refresh' to see finished images, 'View Batch Progress' for active images"),
|
||||
active_refresh_btn: gr.Button.update(variant="primary", value=active_refresh_btn.value),
|
||||
active_stop_btn: gr.Button.update(variant="primary", value=active_stop_btn.value),
|
||||
status_text: gr.Textbox.update(value="Generation has started. Click 'Refresh' for updates")
|
||||
}
|
||||
|
||||
def _call_func(self, func_key: FuncKey, session_key: str) -> List[Component]:
|
||||
@ -292,7 +233,7 @@ class JobManager:
|
||||
except Exception as e:
|
||||
job_info.job_status = f"Error: {e}"
|
||||
print(f"Exception processing job {job_info}: {e}\n{traceback.format_exc()}")
|
||||
raise
|
||||
outputs = []
|
||||
|
||||
# Filter the function output for any removed outputs
|
||||
filtered_output = []
|
||||
@ -313,16 +254,12 @@ class JobManager:
|
||||
|
||||
def _post_call_func(
|
||||
self, func_key: FuncKey, output_dummy_obj: Component, refresh_btn: gr.Button, stop_btn: gr.Button,
|
||||
status_text: gr.Textbox, active_image: gr.Image, active_refresh_btn: gr.Button, active_stop_btn: gr.Button,
|
||||
session_key: str) -> List[Component]:
|
||||
status_text: gr.Textbox, session_key: str) -> List[Component]:
|
||||
''' Called when a job completes '''
|
||||
return {output_dummy_obj: triggerChangeEvent(),
|
||||
refresh_btn: gr.Button.update(variant="secondary", value=refresh_btn.value),
|
||||
stop_btn: gr.Button.update(variant="secondary", value=stop_btn.value),
|
||||
status_text: gr.Textbox.update(value="Generation has finished!"),
|
||||
active_refresh_btn: gr.Button.update(variant="secondary", value=active_refresh_btn.value),
|
||||
active_stop_btn: gr.Button.update(variant="secondary", value=active_stop_btn.value),
|
||||
active_image: gr.Image.update(visible=False)
|
||||
status_text: gr.Textbox.update(value="Generation has finished!")
|
||||
}
|
||||
|
||||
def _update_gallery_event(self, func_key: FuncKey, session_key: str) -> List[Component]:
|
||||
@ -338,15 +275,16 @@ class JobManager:
|
||||
|
||||
return job_info.images
|
||||
|
||||
def _wrap_func(self, func: Callable, inputs: List[Component],
|
||||
outputs: List[Component],
|
||||
job_ui: JobManagerUi) -> Tuple[Callable, List[Component]]:
|
||||
def _wrap_func(
|
||||
self, func: Callable, inputs: List[Component], outputs: List[Component],
|
||||
refresh_btn: gr.Button = None, stop_btn: gr.Button = None,
|
||||
status_text: Optional[gr.Textbox] = None) -> Tuple[Callable, List[Component]]:
|
||||
''' handles JobManageUI's wrap_func'''
|
||||
|
||||
assert gr.context.Context.block is not None, "wrap_func must be called within a 'gr.Blocks' 'with' context"
|
||||
|
||||
# Create a unique key for this job
|
||||
func_key = FuncKey(job_id=uuid.uuid4().hex, func=func)
|
||||
func_key = FuncKey(job_id=uuid.uuid4(), func=func)
|
||||
|
||||
# Create a unique session key (next gradio release can use gr.State, see https://gradio.app/state_in_blocks/)
|
||||
if self._session_key is None:
|
||||
@ -364,6 +302,9 @@ class JobManager:
|
||||
del outputs[idx]
|
||||
break
|
||||
|
||||
# Add the session key to the inputs
|
||||
inputs += [self._session_key]
|
||||
|
||||
# Create dummy objects
|
||||
update_gallery_obj = gr.JSON(visible=False, elem_id="JobManagerDummyObject")
|
||||
update_gallery_obj.change(
|
||||
@ -372,44 +313,20 @@ class JobManager:
|
||||
[gallery_comp]
|
||||
)
|
||||
|
||||
if job_ui._refresh_btn:
|
||||
job_ui._refresh_btn.variant = 'secondary'
|
||||
job_ui._refresh_btn.click(
|
||||
if refresh_btn:
|
||||
refresh_btn.variant = 'secondary'
|
||||
refresh_btn.click(
|
||||
partial(self._refresh_func, func_key),
|
||||
[self._session_key],
|
||||
[update_gallery_obj, job_ui._status_text]
|
||||
[update_gallery_obj, status_text]
|
||||
)
|
||||
|
||||
if job_ui._stop_btn:
|
||||
job_ui._stop_btn.variant = 'secondary'
|
||||
job_ui._stop_btn.click(
|
||||
if stop_btn:
|
||||
stop_btn.variant = 'secondary'
|
||||
stop_btn.click(
|
||||
partial(self._stop_wrapped_func, func_key),
|
||||
[self._session_key],
|
||||
[job_ui._status_text]
|
||||
)
|
||||
|
||||
if job_ui._active_image and job_ui._active_image_refresh_btn:
|
||||
job_ui._active_image_refresh_btn.click(
|
||||
partial(self._refresh_cur_iter_func, func_key),
|
||||
[self._session_key],
|
||||
[job_ui._active_image, job_ui._status_text]
|
||||
)
|
||||
|
||||
if job_ui._active_image_stop_btn:
|
||||
job_ui._active_image_stop_btn.click(
|
||||
partial(self._stop_cur_iter_func, func_key),
|
||||
[self._session_key],
|
||||
[job_ui._active_image, job_ui._status_text]
|
||||
)
|
||||
|
||||
if job_ui._stop_all_session_btn:
|
||||
job_ui._stop_all_session_btn.click(
|
||||
self.stop_all_jobs, [], []
|
||||
)
|
||||
|
||||
if job_ui._free_done_sessions_btn:
|
||||
job_ui._free_done_sessions_btn.click(
|
||||
self.clear_all_finished_jobs, [], []
|
||||
[status_text]
|
||||
)
|
||||
|
||||
# (ab)use gr.JSON to forward events.
|
||||
@ -426,8 +343,7 @@ class JobManager:
|
||||
# Since some parameters are optional it makes sense to use the 'dict' return value type, which requires
|
||||
# the Component as a key... so group together the UI components that the event listeners are going to update
|
||||
# to make it easy to append to function calls and outputs
|
||||
job_ui_params = [job_ui._refresh_btn, job_ui._stop_btn, job_ui._status_text,
|
||||
job_ui._active_image, job_ui._active_image_refresh_btn, job_ui._active_image_stop_btn]
|
||||
job_ui_params = [refresh_btn, stop_btn, status_text]
|
||||
job_ui_outputs = [comp for comp in job_ui_params if comp is not None]
|
||||
|
||||
# Here a chain is constructed that will make a 'pre' call, a 'run' call, and a 'post' call,
|
||||
@ -453,39 +369,27 @@ class JobManager:
|
||||
[call_dummyobj] + job_ui_outputs
|
||||
)
|
||||
|
||||
# Add any components that we want the runtime values for
|
||||
added_inputs = [self._session_key, job_ui._rec_steps_checkbox, job_ui._save_rec_steps_to_gallery_chkbx,
|
||||
job_ui._save_rec_steps_to_file_chkbx, job_ui._rec_steps_intrvl_sldr]
|
||||
|
||||
# Now replace the original function with one that creates a JobInfo and triggers the dummy obj
|
||||
def wrapped_func(*wrapped_inputs):
|
||||
# Remove the added_inputs (pop opposite order of list)
|
||||
|
||||
wrapped_inputs = list(wrapped_inputs)
|
||||
rec_steps_interval: int = wrapped_inputs.pop()
|
||||
save_rec_steps_file: bool = wrapped_inputs.pop()
|
||||
save_rec_steps_grid: bool = wrapped_inputs.pop()
|
||||
record_steps_enabled: bool = wrapped_inputs.pop()
|
||||
session_key: str = wrapped_inputs.pop()
|
||||
job_inputs = tuple(wrapped_inputs)
|
||||
def wrapped_func(*inputs):
|
||||
session_key = inputs[-1]
|
||||
inputs = inputs[:-1]
|
||||
|
||||
# Get or create a session for this key
|
||||
session_info = self._sessions.setdefault(session_key, SessionInfo())
|
||||
|
||||
# Is this session already running this job?
|
||||
if func_key in session_info.jobs:
|
||||
return {job_ui._status_text: "This session is already running that function!"}
|
||||
return {status_text: "This session is already running that function!"}
|
||||
|
||||
job_token = self._get_job_token(block=False)
|
||||
job = JobInfo(
|
||||
inputs=job_inputs, func=func, removed_output_idxs=removed_idxs, session_key=session_key,
|
||||
job_token=job_token, rec_steps_enabled=record_steps_enabled, rec_steps_intrvl=rec_steps_interval,
|
||||
rec_steps_to_gallery=save_rec_steps_grid, rec_steps_to_file=save_rec_steps_file)
|
||||
job = JobInfo(inputs=inputs, func=func, removed_output_idxs=removed_idxs, session_key=session_key,
|
||||
job_token=job_token)
|
||||
session_info.jobs[func_key] = job
|
||||
|
||||
ret = {pre_call_dummyobj: triggerChangeEvent()}
|
||||
if job_token is None:
|
||||
ret[job_ui._status_text] = "Job is queued"
|
||||
ret[status_text] = "Job is queued"
|
||||
return ret
|
||||
|
||||
return wrapped_func, inputs + added_inputs, [pre_call_dummyobj, job_ui._status_text]
|
||||
return wrapped_func, inputs, [pre_call_dummyobj, status_text]
|
||||
|
@ -6,33 +6,17 @@ import base64
|
||||
import re
|
||||
|
||||
|
||||
def change_image_editor_mode(choice, cropped_image, resize_mode, width, height):
|
||||
def change_image_editor_mode(choice, cropped_image, masked_image, resize_mode, width, height):
|
||||
if choice == "Mask":
|
||||
return [gr.Image.update(visible=False),
|
||||
gr.Image.update(visible=True),
|
||||
gr.Button.update("Generate", variant="primary", visible=False),
|
||||
gr.Button.update("Generate", variant="primary", visible=True),
|
||||
gr.Button.update("Advanced Editor", visible=False),
|
||||
gr.Radio.update(choices=["Keep masked area", "Regenerate only masked area"],
|
||||
label="Mask Mode",
|
||||
value="Regenerate only masked area", visible=True),
|
||||
gr.Slider.update(minimum=1, maximum=10, step=1, label="How much blurry should the mask be? (to avoid hard edges)", value=3, visible=True),
|
||||
gr.Image.update(interactive=True)]
|
||||
else:
|
||||
return [gr.Image.update(visible=True),
|
||||
gr.Image.update(visible=False),
|
||||
gr.Button.update("Generate", variant="primary", visible=True),
|
||||
gr.Button.update("Generate", variant="primary", visible=False),
|
||||
gr.Button.update("Advanced Editor", visible=True),
|
||||
gr.Radio.update(choices=["Keep masked area", "Regenerate only masked area"],
|
||||
label="Mask Mode",
|
||||
value="Regenerate only masked area", visible=False),
|
||||
gr.Slider.update(minimum=1, maximum=10, step=1, label="How much blurry should the mask be? (to avoid hard edges)", value=3, visible=False),
|
||||
gr.Image.update(interactive=False)]
|
||||
update_image_result = update_image_mask(cropped_image, resize_mode, width, height)
|
||||
return [gr.update(visible=False), update_image_result, gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), gr.update(visible=True), gr.update(visible=True)]
|
||||
|
||||
update_image_result = update_image_mask(masked_image["image"], resize_mode, width, height)
|
||||
return [update_image_result, gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), gr.update(visible=False)]
|
||||
|
||||
def update_image_mask(cropped_image, resize_mode, width, height):
|
||||
resized_cropped_image = resize_image(resize_mode, cropped_image, width, height) if cropped_image else None
|
||||
return gr.Image.update(value=resized_cropped_image)
|
||||
return gr.update(value=resized_cropped_image, visible=True)
|
||||
|
||||
def toggle_options_gfpgan(selection):
|
||||
if 0 in selection:
|
||||
|
356
scripts/webui.py
356
scripts/webui.py
@ -1,5 +1,7 @@
|
||||
import argparse, os, sys, glob, re
|
||||
|
||||
import cv2
|
||||
|
||||
from frontend.frontend import draw_gradio_ui
|
||||
from frontend.job_manager import JobManager, JobInfo
|
||||
from frontend.ui_functions import resize_image
|
||||
@ -37,11 +39,9 @@ parser.add_argument("--save-metadata", action='store_true', help="Store generati
|
||||
parser.add_argument("--share-password", type=str, help="Sharing is open by default, use this to set a password. Username: webui", default=None)
|
||||
parser.add_argument("--share", action='store_true', help="Should share your server on gradio.app, this allows you to use the UI from your mobile app", default=False)
|
||||
parser.add_argument("--skip-grid", action='store_true', help="do not save a grid, only individual samples. Helpful when evaluating lots of samples", default=False)
|
||||
parser.add_argument("--save-each", action='store_true', help="save individual samples. For speed measurements.", default=False)
|
||||
parser.add_argument("--skip-save", action='store_true', help="do not save indiviual samples. For speed measurements.", default=False)
|
||||
parser.add_argument('--no-job-manager', action='store_true', help="Don't use the experimental job manager on top of gradio", default=False)
|
||||
parser.add_argument("--max-jobs", type=int, help="Maximum number of concurrent 'generate' commands", default=1)
|
||||
parser.add_argument("--custom-css", action='store_true', help="Place custom.css in css folder to load a custom theme of the UI", default=False)
|
||||
|
||||
opt = parser.parse_args()
|
||||
|
||||
#Should not be needed anymore
|
||||
@ -66,12 +66,9 @@ import torch
|
||||
import torch.nn as nn
|
||||
import yaml
|
||||
import glob
|
||||
import copy
|
||||
from typing import List, Union, Dict, Callable, Any
|
||||
from typing import List, Union, Dict
|
||||
from pathlib import Path
|
||||
from collections import namedtuple
|
||||
import cv2
|
||||
from functools import partial
|
||||
|
||||
from contextlib import contextmanager, nullcontext
|
||||
from einops import rearrange, repeat
|
||||
@ -109,7 +106,6 @@ invalid_filename_chars = '<>:"/\|?*\n'
|
||||
GFPGAN_dir = opt.gfpgan_dir
|
||||
RealESRGAN_dir = opt.realesrgan_dir
|
||||
LDSR_dir = opt.ldsr_dir
|
||||
returned_info = {}
|
||||
|
||||
if opt.optimized_turbo:
|
||||
opt.optimized = True
|
||||
@ -140,13 +136,6 @@ elif grid_format[0] == 'webp':
|
||||
grid_quality = abs(grid_quality)
|
||||
|
||||
|
||||
def toImgOpenCV(imgPIL): # Conver imgPIL to imgOpenCV
|
||||
i = np.array(imgPIL) # After mapping from PIL to numpy : [R,G,B,A]
|
||||
# numpy Image Channel system: [B,G,R,A]
|
||||
red = i[:,:,0].copy(); i[:,:,0] = i[:,:,2].copy(); i[:,:,2] = red
|
||||
return i
|
||||
def toImgPIL(imgOpenCV): return Image.fromarray(cv2.cvtColor(imgOpenCV, cv2.COLOR_BGR2RGB))
|
||||
|
||||
def chunk(it, size):
|
||||
it = iter(it)
|
||||
return iter(lambda: tuple(islice(it, size)), ())
|
||||
@ -275,21 +264,15 @@ class KDiffusionSampler:
|
||||
self.schedule = sampler
|
||||
def get_sampler_name(self):
|
||||
return self.schedule
|
||||
def sample(self, S, conditioning, batch_size, shape, verbose, unconditional_guidance_scale, unconditional_conditioning, eta, x_T, img_callback: Callable = None ):
|
||||
def sample(self, S, conditioning, batch_size, shape, verbose, unconditional_guidance_scale, unconditional_conditioning, eta, x_T):
|
||||
sigmas = self.model_wrap.get_sigmas(S)
|
||||
x = x_T * sigmas[0]
|
||||
model_wrap_cfg = CFGDenoiser(self.model_wrap)
|
||||
samples_ddim = K.sampling.__dict__[f'sample_{self.schedule}'](model_wrap_cfg, x, sigmas, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': unconditional_guidance_scale}, disable=False, callback=partial(KDiffusionSampler.img_callback_wrapper, img_callback))
|
||||
|
||||
samples_ddim = K.sampling.__dict__[f'sample_{self.schedule}'](model_wrap_cfg, x, sigmas, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': unconditional_guidance_scale}, disable=False)
|
||||
|
||||
return samples_ddim, None
|
||||
|
||||
@classmethod
|
||||
def img_callback_wrapper(cls, callback: Callable, *args):
|
||||
''' Converts a KDiffusion callback to the standard img_callback '''
|
||||
if callback:
|
||||
arg_dict = args[0]
|
||||
callback(image_sample=arg_dict['denoised'], iter_num=arg_dict['i'])
|
||||
|
||||
|
||||
def create_random_tensors(shape, seeds):
|
||||
xs = []
|
||||
@ -524,7 +507,6 @@ def seed_to_int(s):
|
||||
n = n >> 32
|
||||
return n
|
||||
|
||||
|
||||
def draw_prompt_matrix(im, width, height, all_prompts):
|
||||
def wrap(text, d, font, line_length):
|
||||
lines = ['']
|
||||
@ -590,63 +572,6 @@ def draw_prompt_matrix(im, width, height, all_prompts):
|
||||
return result
|
||||
|
||||
|
||||
def round_to_multiple(dimension, dimension_ceiling, multiple=64, round_down=True):
|
||||
if round_down:
|
||||
rounded_dimension = multiple * math.ceil(dimension / multiple)
|
||||
else:
|
||||
rounded_dimension = multiple * math.floor(dimension / multiple)
|
||||
return rounded_dimension
|
||||
|
||||
|
||||
def crop_image(img, mask, width, height):
|
||||
def get_mask_and_img(img, mask,dimension, coords, target_width, target_height):
|
||||
longest_target_dimension = round_to_multiple(dimension, dimension)
|
||||
func_crop_coords = (coords[0], coords[1], coords[0]+longest_target_dimension, coords[1]+longest_target_dimension)
|
||||
resized_img = img.crop(func_crop_coords)
|
||||
scale_dimension = target_width if target_width > target_height else target_height
|
||||
resized_img = resized_img.resize((scale_dimension, scale_dimension), resample=Image.Resampling.LANCZOS)
|
||||
|
||||
resized_mask = mask.crop(func_crop_coords)
|
||||
cropped_img_width, cropped_img_height = resized_mask.size
|
||||
resized_mask = resized_mask.resize((scale_dimension, scale_dimension), resample=Image.Resampling.LANCZOS)
|
||||
|
||||
alpha_mask = resized_mask.convert("RGBA")
|
||||
mask_data = alpha_mask.getdata()
|
||||
container = []
|
||||
for item in mask_data:
|
||||
if item[0] == 0 and item[1] == 0 and item[2] == 0:
|
||||
container.append((255, 255, 255, 0))
|
||||
else:
|
||||
container.append(item)
|
||||
alpha_mask.putdata(container)
|
||||
|
||||
results = {
|
||||
"cropped_img": resized_img,
|
||||
"org_img": rgb_image,
|
||||
"cropped_mask": alpha_mask,
|
||||
"coords": crop_coords,
|
||||
"scale_width": width,
|
||||
"scale_height": height,
|
||||
"org_width": cropped_img_width,
|
||||
"org_height": cropped_img_height
|
||||
}
|
||||
return results
|
||||
|
||||
rgb_image = img.convert("RGB")
|
||||
rgb_mask = mask.convert("RGB")
|
||||
np_mask = np.array(rgb_mask)
|
||||
white_columns = np.where(np_mask.max(axis=0)>= 255)[0]
|
||||
white_rows = np.where(np_mask.max(axis=1)>= 255)[0]
|
||||
crop_coords = (min(white_columns), min(white_rows), max(white_columns), max(white_rows))
|
||||
crop_to_size = rgb_image.crop(crop_coords)
|
||||
cropped_img_width, cropped_img_height = crop_to_size.size
|
||||
|
||||
if cropped_img_width > cropped_img_height:
|
||||
results_dict = get_mask_and_img(rgb_image, mask, cropped_img_width, crop_coords, width, height)
|
||||
else:
|
||||
results_dict = get_mask_and_img(rgb_image, mask, cropped_img_height, crop_coords, width, height)
|
||||
|
||||
return results_dict
|
||||
|
||||
|
||||
def check_prompt_length(prompt, comments):
|
||||
@ -668,8 +593,8 @@ def check_prompt_length(prompt, comments):
|
||||
comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")
|
||||
|
||||
def save_sample(image, sample_path_i, filename, jpg_sample, prompts, seeds, width, height, steps, cfg_scale,
|
||||
normalize_prompt_weights, use_GFPGAN, write_info_files, write_sample_info_to_log_file, prompt_matrix, init_img, uses_loopback, uses_random_seed_loopback, save_each,
|
||||
skip_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoising_strength, resize_mode, skip_metadata=True):
|
||||
normalize_prompt_weights, use_GFPGAN, write_info_files, write_sample_info_to_log_file, prompt_matrix, init_img, uses_loopback, uses_random_seed_loopback, skip_save,
|
||||
skip_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoising_strength, resize_mode, skip_metadata=True):
|
||||
filename_i = os.path.join(sample_path_i, filename)
|
||||
if not jpg_sample:
|
||||
if opt.save_metadata and not skip_metadata:
|
||||
@ -702,7 +627,7 @@ def save_sample(image, sample_path_i, filename, jpg_sample, prompts, seeds, widt
|
||||
toggles.append(2)
|
||||
if uses_random_seed_loopback:
|
||||
toggles.append(3)
|
||||
if save_each:
|
||||
if not skip_save:
|
||||
toggles.append(2 + offset)
|
||||
if not skip_grid:
|
||||
toggles.append(3 + offset)
|
||||
@ -852,12 +777,12 @@ def oxlamon_matrix(prompt, seed, n_iter, batch_size):
|
||||
|
||||
|
||||
def process_images(
|
||||
outpath, func_init, func_sample, prompt, seed, sampler_name, skip_grid, save_each, batch_size,
|
||||
outpath, func_init, func_sample, prompt, seed, sampler_name, skip_grid, skip_save, batch_size,
|
||||
n_iter, steps, cfg_scale, width, height, prompt_matrix, use_GFPGAN, use_RealESRGAN, realesrgan_model_name,
|
||||
fp, ddim_eta=0.0, do_not_save_grid=False, normalize_prompt_weights=True, init_img=None, init_mask=None,
|
||||
keep_mask=False, mask_blur_strength=3, denoising_strength=0.75, resize_mode=None, uses_loopback=False,
|
||||
uses_random_seed_loopback=False, sort_samples=True, write_info_files=True, write_sample_info_to_log_file=False, jpg_sample=False,
|
||||
variant_amount=0.0, variant_seed=None,imgProcessorTask=False,resize_mask=False, job_info: JobInfo = None):
|
||||
variant_amount=0.0, variant_seed=None,imgProcessorTask=False, job_info: JobInfo = None):
|
||||
"""this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch"""
|
||||
prompt = prompt or ''
|
||||
torch_gc()
|
||||
@ -956,7 +881,6 @@ def process_images(
|
||||
|
||||
if job_info:
|
||||
job_info.job_status = f"Processing Iteration {n+1}/{n_iter}. Batch size {batch_size}"
|
||||
job_info.rec_steps_imgs.clear()
|
||||
for idx,(p,s) in enumerate(zip(prompts,seeds)):
|
||||
job_info.job_status += f"\nItem {idx}: Seed {s}\nPrompt: {p}"
|
||||
|
||||
@ -987,7 +911,7 @@ def process_images(
|
||||
while(torch.cuda.memory_allocated()/1e6 >= mem):
|
||||
time.sleep(1)
|
||||
|
||||
cur_variant_amount = variant_amount
|
||||
cur_variant_amount = variant_amount
|
||||
if variant_amount == 0.0:
|
||||
# we manually generate all input noises because each one should have a specific seed
|
||||
x = create_random_tensors(shape, seeds=seeds)
|
||||
@ -1010,78 +934,17 @@ def process_images(
|
||||
# finally, slerp base_x noise to target_x noise for creating a variant
|
||||
x = slerp(device, max(0.0, min(1.0, cur_variant_amount)), base_x, target_x)
|
||||
|
||||
|
||||
# If in optimized mode then make a CPU-copy of the model to generate preview images
|
||||
if opt.optimized:
|
||||
step_preview_model = copy.deepcopy(modelFS).to("cpu")
|
||||
if not opt.no_half:
|
||||
step_preview_model.float()
|
||||
else:
|
||||
step_preview_model = model
|
||||
|
||||
def sample_iteration_callback(image_sample: torch.Tensor, iter_num: int):
|
||||
''' Called from the sampler every iteration '''
|
||||
if job_info:
|
||||
job_info.active_iteration_cnt = iter_num
|
||||
record_periodic_image = job_info.rec_steps_enabled and (0 == iter_num % job_info.rec_steps_intrvl)
|
||||
if record_periodic_image or job_info.refresh_active_image_requested.is_set():
|
||||
preview_start_time = time.time()
|
||||
if opt.optimized:
|
||||
image_sample = image_sample.to("cpu")
|
||||
|
||||
batch_ddim = step_preview_model.decode_first_stage(image_sample)
|
||||
batch_ddim = torch.clamp((batch_ddim + 1.0) / 2.0, min=0.0, max=1.0)
|
||||
preview_elapsed_timed = time.time() - preview_start_time
|
||||
|
||||
if preview_elapsed_timed > 1:
|
||||
print(
|
||||
f"Warning: Preview generation is slow! It took {preview_elapsed_timed:.2f}s to generate one preview!")
|
||||
|
||||
images: List[Image.Image] = []
|
||||
# Convert tensor to image (copied from code below)
|
||||
for ddim in batch_ddim:
|
||||
x_sample = 255. * rearrange(ddim.cpu().numpy(), 'c h w -> h w c')
|
||||
x_sample = x_sample.astype(np.uint8)
|
||||
image = Image.fromarray(x_sample)
|
||||
images.append(image)
|
||||
|
||||
caption = f"Iter {iter_num}"
|
||||
grid = image_grid(images, len(images), force_n_rows=1, captions=[caption]*len(images))
|
||||
|
||||
# Save the images if recording steps, and append existing saved steps
|
||||
if job_info.rec_steps_enabled:
|
||||
gallery_img_size = tuple( int(0.25*dim) for dim in images[0].size)
|
||||
job_info.rec_steps_imgs.append(grid.resize(gallery_img_size))
|
||||
|
||||
# Notify the requester that the image is updated
|
||||
if job_info.refresh_active_image_requested.is_set():
|
||||
if job_info.rec_steps_enabled:
|
||||
grid = image_grid(job_info.rec_steps_imgs, 1)
|
||||
job_info.active_image = grid
|
||||
job_info.refresh_active_image_done.set()
|
||||
job_info.refresh_active_image_requested.clear()
|
||||
|
||||
# Interrupt current iteration?
|
||||
if job_info.stop_cur_iter.is_set():
|
||||
job_info.stop_cur_iter.clear()
|
||||
raise StopIteration()
|
||||
|
||||
try:
|
||||
samples_ddim = func_sample(init_data=init_data, x=x, conditioning=c, unconditional_conditioning=uc, sampler_name=sampler_name, img_callback=sample_iteration_callback)
|
||||
except StopIteration:
|
||||
print("Skipping iteration")
|
||||
job_info.job_status = "Skipping iteration"
|
||||
continue
|
||||
samples_ddim = func_sample(init_data=init_data, x=x, conditioning=c, unconditional_conditioning=uc, sampler_name=sampler_name)
|
||||
|
||||
if opt.optimized:
|
||||
modelFS.to(device)
|
||||
|
||||
|
||||
|
||||
x_samples_ddim = (model if not opt.optimized else modelFS).decode_first_stage(samples_ddim)
|
||||
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
|
||||
for i, x_sample in enumerate(x_samples_ddim):
|
||||
sanitized_prompt = prompts[i].replace(' ', '_').translate({ord(x): '' for x in invalid_filename_chars})
|
||||
sanitized_prompt = sanitized_prompt.lower()
|
||||
if variant_seed != None and variant_seed != '':
|
||||
if variant_amount == 0.0:
|
||||
seed_used = f"{current_seeds[i]}-{variant_seed}"
|
||||
@ -1106,17 +969,6 @@ def process_images(
|
||||
image = Image.fromarray(x_sample)
|
||||
original_sample = x_sample
|
||||
original_filename = filename
|
||||
|
||||
if resize_mask:
|
||||
scaled_img = image.resize((returned_info["org_width"], returned_info["org_height"]), resample=Image.Resampling.LANCZOS).convert("RGB")
|
||||
scaled_mask = returned_info["cropped_mask"].resize((returned_info["org_width"], returned_info["org_height"]), resample=Image.Resampling.LANCZOS).convert("RGBA")
|
||||
scaled_mask = scaled_mask.filter(ImageFilter.GaussianBlur(mask_blur_strength))
|
||||
returned_info["org_img"].paste(scaled_img, (returned_info["coords"][0], returned_info["coords"][1]), mask=scaled_mask)
|
||||
image = returned_info["org_img"].copy()
|
||||
original_sample = np.asarray(image).astype(np.uint8)
|
||||
#returned_info["org_img"].save(sample_path_i+"\\"+filename+" test.png", format="PNG")
|
||||
|
||||
|
||||
if use_GFPGAN and GFPGAN is not None and not use_RealESRGAN:
|
||||
skip_save = True # #287 >_>
|
||||
torch_gc()
|
||||
@ -1124,12 +976,10 @@ def process_images(
|
||||
gfpgan_sample = restored_img[:,:,::-1]
|
||||
gfpgan_image = Image.fromarray(gfpgan_sample)
|
||||
gfpgan_filename = original_filename + '-gfpgan'
|
||||
if save_each:
|
||||
save_sample(gfpgan_image, sample_path_i, gfpgan_filename, jpg_sample, prompts, seeds, width, height, steps, cfg_scale,
|
||||
normalize_prompt_weights, use_GFPGAN, write_info_files, write_sample_info_to_log_file, prompt_matrix, init_img, uses_loopback, uses_random_seed_loopback, save_each,
|
||||
skip_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoising_strength, resize_mode, skip_metadata=True)
|
||||
save_sample(gfpgan_image, sample_path_i, gfpgan_filename, jpg_sample, prompts, seeds, width, height, steps, cfg_scale,
|
||||
normalize_prompt_weights, use_GFPGAN, write_info_files, write_sample_info_to_log_file, prompt_matrix, init_img, uses_loopback, uses_random_seed_loopback, skip_save,
|
||||
skip_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoising_strength, resize_mode, skip_metadata=True)
|
||||
output_images.append(gfpgan_image) #287
|
||||
# save_each = True # #287 >_>
|
||||
#if simple_templating:
|
||||
# grid_captions.append( captions[i] + "\ngfpgan" )
|
||||
|
||||
@ -1140,30 +990,26 @@ def process_images(
|
||||
esrgan_filename = original_filename + '-esrgan4x'
|
||||
esrgan_sample = output[:,:,::-1]
|
||||
esrgan_image = Image.fromarray(esrgan_sample)
|
||||
if save_each:
|
||||
save_sample(esrgan_image, sample_path_i, esrgan_filename, jpg_sample, prompts, seeds, width, height, steps, cfg_scale,
|
||||
normalize_prompt_weights, use_GFPGAN, write_info_files, write_sample_info_to_log_file, prompt_matrix, init_img, uses_loopback, uses_random_seed_loopback, save_each,
|
||||
skip_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoising_strength, resize_mode, skip_metadata=True)
|
||||
save_sample(esrgan_image, sample_path_i, esrgan_filename, jpg_sample, prompts, seeds, width, height, steps, cfg_scale,
|
||||
normalize_prompt_weights, use_GFPGAN,write_info_files, write_sample_info_to_log_file, prompt_matrix, init_img, uses_loopback, uses_random_seed_loopback, skip_save,
|
||||
skip_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoising_strength, resize_mode, skip_metadata=True)
|
||||
output_images.append(esrgan_image) #287
|
||||
# save_each = False # #287 >_>
|
||||
#if simple_templating:
|
||||
# grid_captions.append( captions[i] + "\nesrgan" )
|
||||
|
||||
if use_RealESRGAN and RealESRGAN is not None and use_GFPGAN and GFPGAN is not None:
|
||||
skip_save = True # #287 >_>
|
||||
torch_gc()
|
||||
cropped_faces, restored_faces, restored_img = GFPGAN.enhance(original_sample[:,:,::-1], has_aligned=False, only_center_face=False, paste_back=True)
|
||||
cropped_faces, restored_faces, restored_img = GFPGAN.enhance(x_sample[:,:,::-1], has_aligned=False, only_center_face=False, paste_back=True)
|
||||
gfpgan_sample = restored_img[:,:,::-1]
|
||||
output, img_mode = RealESRGAN.enhance(gfpgan_sample[:,:,::-1])
|
||||
gfpgan_esrgan_filename = original_filename + '-gfpgan-esrgan4x'
|
||||
gfpgan_esrgan_sample = output[:,:,::-1]
|
||||
gfpgan_esrgan_image = Image.fromarray(gfpgan_esrgan_sample)
|
||||
if save_each:
|
||||
save_sample(gfpgan_esrgan_image, sample_path_i, gfpgan_esrgan_filename, jpg_sample, prompts, seeds, width, height, steps, cfg_scale,
|
||||
normalize_prompt_weights, use_GFPGAN, write_info_files, write_sample_info_to_log_file, prompt_matrix, init_img, uses_loopback, uses_random_seed_loopback, save_each,
|
||||
skip_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoising_strength, resize_mode, skip_metadata=True)
|
||||
save_sample(gfpgan_esrgan_image, sample_path_i, gfpgan_esrgan_filename, jpg_sample, prompts, seeds, width, height, steps, cfg_scale,
|
||||
normalize_prompt_weights, use_GFPGAN, write_info_files, write_sample_info_to_log_file, prompt_matrix, init_img, uses_loopback, uses_random_seed_loopback, skip_save,
|
||||
skip_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoising_strength, resize_mode, skip_metadata=True)
|
||||
output_images.append(gfpgan_esrgan_image) #287
|
||||
# save_each = False # #287 >_>
|
||||
#if simple_templating:
|
||||
# grid_captions.append( captions[i] + "\ngfpgan_esrgan" )
|
||||
|
||||
@ -1171,30 +1017,15 @@ def process_images(
|
||||
if imgProcessorTask == True:
|
||||
output_images.append(image)
|
||||
|
||||
|
||||
if save_each:
|
||||
if not skip_save:
|
||||
save_sample(image, sample_path_i, filename, jpg_sample, prompts, seeds, width, height, steps, cfg_scale,
|
||||
normalize_prompt_weights, use_GFPGAN, write_info_files, write_sample_info_to_log_file, prompt_matrix, init_img, uses_loopback, uses_random_seed_loopback, save_each,
|
||||
skip_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoising_strength, resize_mode, False)
|
||||
normalize_prompt_weights, use_GFPGAN, write_info_files, write_sample_info_to_log_file, prompt_matrix, init_img, uses_loopback, uses_random_seed_loopback, skip_save,
|
||||
skip_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoising_strength, resize_mode, False)
|
||||
if add_original_image or not simple_templating:
|
||||
output_images.append(image)
|
||||
if simple_templating:
|
||||
grid_captions.append( captions[i] )
|
||||
|
||||
# Save the progress images?
|
||||
if job_info:
|
||||
if job_info.rec_steps_enabled and (job_info.rec_steps_to_file or job_info.rec_steps_to_gallery):
|
||||
steps_grid = image_grid(job_info.rec_steps_imgs, 1)
|
||||
if job_info.rec_steps_to_gallery:
|
||||
gallery_img_size = tuple(2*dim for dim in image.size)
|
||||
output_images.append( steps_grid.resize( gallery_img_size ) )
|
||||
if job_info.rec_steps_to_file:
|
||||
steps_grid_filename = f"{original_filename}_step_grid"
|
||||
save_sample(steps_grid, sample_path_i, steps_grid_filename, jpg_sample, prompts, seeds, width, height, steps, cfg_scale,
|
||||
normalize_prompt_weights, use_GFPGAN, write_info_files, write_sample_info_to_log_file, prompt_matrix, init_img, uses_loopback, uses_random_seed_loopback, skip_save,
|
||||
skip_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoising_strength, resize_mode, False)
|
||||
|
||||
|
||||
if opt.optimized:
|
||||
mem = torch.cuda.memory_allocated()/1e6
|
||||
modelFS.to("cpu")
|
||||
@ -1263,7 +1094,7 @@ def txt2img(prompt: str, ddim_steps: int, sampler_name: str, toggles: List[int],
|
||||
seed = seed_to_int(seed)
|
||||
prompt_matrix = 0 in toggles
|
||||
normalize_prompt_weights = 1 in toggles
|
||||
save_each = 2 in toggles
|
||||
skip_save = 2 not in toggles
|
||||
skip_grid = 3 not in toggles
|
||||
sort_samples = 4 in toggles
|
||||
write_info_files = 5 in toggles
|
||||
@ -1302,8 +1133,8 @@ def txt2img(prompt: str, ddim_steps: int, sampler_name: str, toggles: List[int],
|
||||
def init():
|
||||
pass
|
||||
|
||||
def sample(init_data, x, conditioning, unconditional_conditioning, sampler_name, img_callback: Callable = None):
|
||||
samples_ddim, _ = sampler.sample(S=ddim_steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=cfg_scale, unconditional_conditioning=unconditional_conditioning, eta=ddim_eta, x_T=x, img_callback=img_callback)
|
||||
def sample(init_data, x, conditioning, unconditional_conditioning, sampler_name):
|
||||
samples_ddim, _ = sampler.sample(S=ddim_steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=cfg_scale, unconditional_conditioning=unconditional_conditioning, eta=ddim_eta, x_T=x)
|
||||
return samples_ddim
|
||||
|
||||
try:
|
||||
@ -1314,7 +1145,7 @@ def txt2img(prompt: str, ddim_steps: int, sampler_name: str, toggles: List[int],
|
||||
prompt=prompt,
|
||||
seed=seed,
|
||||
sampler_name=sampler_name,
|
||||
save_each=save_each,
|
||||
skip_save=skip_save,
|
||||
skip_grid=skip_grid,
|
||||
batch_size=batch_size,
|
||||
n_iter=n_iter,
|
||||
@ -1393,9 +1224,14 @@ class Flagging(gr.FlaggingCallback):
|
||||
print("Logged:", filenames[0])
|
||||
|
||||
|
||||
def img2img(prompt: str, image_editor_mode: str, init_info: any, init_info_mask: any, mask_mode: str, mask_blur_strength: int, ddim_steps: int, sampler_name: str,
|
||||
def img2img(prompt: str, image_editor_mode: str, mask_mode: str, mask_blur_strength: int, ddim_steps: int, sampler_name: str,
|
||||
toggles: List[int], realesrgan_model_name: str, n_iter: int, cfg_scale: float, denoising_strength: float,
|
||||
seed: int, height: int, width: int, resize_mode: int, fp = None, job_info: JobInfo = None):
|
||||
seed: int, height: int, width: int, resize_mode: int, init_info: any = None, init_info_mask: any = None, fp = None, job_info: JobInfo = None):
|
||||
print([prompt, image_editor_mode, init_info, init_info_mask, mask_mode,
|
||||
mask_blur_strength, ddim_steps, sampler_name, toggles,
|
||||
realesrgan_model_name, n_iter, cfg_scale,
|
||||
denoising_strength, seed, height, width, resize_mode,
|
||||
fp])
|
||||
outpath = opt.outdir_img2img or opt.outdir or "outputs/img2img-samples"
|
||||
err = False
|
||||
seed = seed_to_int(seed)
|
||||
@ -1406,7 +1242,7 @@ def img2img(prompt: str, image_editor_mode: str, init_info: any, init_info_mask:
|
||||
normalize_prompt_weights = 1 in toggles
|
||||
loopback = 2 in toggles
|
||||
random_seed_loopback = 3 in toggles
|
||||
save_each = 4 in toggles
|
||||
skip_save = 4 not in toggles
|
||||
skip_grid = 5 not in toggles
|
||||
sort_samples = 6 in toggles
|
||||
write_info_files = 7 in toggles
|
||||
@ -1441,44 +1277,35 @@ def img2img(prompt: str, image_editor_mode: str, init_info: any, init_info_mask:
|
||||
raise Exception("Unknown sampler: " + sampler_name)
|
||||
|
||||
if image_editor_mode == 'Mask':
|
||||
global returned_info
|
||||
init_img = init_info_mask["image"]
|
||||
init_img = init_img.convert("RGB")
|
||||
init_img = resize_image(resize_mode, init_img, width, height)
|
||||
image = image.convert("RGB")
|
||||
init_img = init_img.convert("RGB")
|
||||
init_mask = init_info_mask["mask"]
|
||||
init_mask = resize_image(resize_mode, init_mask, width, height)
|
||||
resize_mask = mask_mode == 2
|
||||
|
||||
if resize_mask:
|
||||
returned_info = crop_image(init_img, init_mask, width, height)
|
||||
init_img = returned_info["cropped_img"]
|
||||
init_mask = returned_info["cropped_mask"]
|
||||
|
||||
keep_mask = mask_mode == 0
|
||||
init_mask = init_mask.convert("RGB")
|
||||
init_mask = resize_image(resize_mode, init_mask, width, height)
|
||||
init_mask = init_mask.convert("RGB")
|
||||
keep_mask = mask_mode == 0
|
||||
init_mask = init_mask if keep_mask else ImageOps.invert(init_mask)
|
||||
else:
|
||||
init_img = init_info.convert("RGB")
|
||||
init_img = init_info
|
||||
init_mask = None
|
||||
keep_mask = False
|
||||
resize_mask = False
|
||||
|
||||
assert 0. <= denoising_strength <= 1., 'can only work with strength in [0.0, 1.0]'
|
||||
t_enc = int(denoising_strength * ddim_steps)
|
||||
|
||||
def init():
|
||||
image = init_img.convert("RGB")
|
||||
if resize_mask:
|
||||
image = resize_image(resize_mode, image, width, height)
|
||||
#image = image.convert("RGB") #todo: mask mode -> ValueError: could not convert string to float:
|
||||
image = resize_image(resize_mode, image, width, height)
|
||||
#image = image.convert("RGB")
|
||||
image = np.array(image).astype(np.float32) / 255.0
|
||||
image = image[None].transpose(0, 3, 1, 2)
|
||||
image = torch.from_numpy(image)
|
||||
|
||||
mask_channel = None
|
||||
if image_editor_mode == "Uncrop":
|
||||
alpha = init_img.convert("RGB")
|
||||
alpha = init_img.convert("RGBA")
|
||||
alpha = resize_image(resize_mode, alpha, width // 8, height // 8)
|
||||
mask_channel = alpha.split()[-1]
|
||||
mask_channel = mask_channel.filter(ImageFilter.GaussianBlur(4))
|
||||
@ -1486,7 +1313,7 @@ def img2img(prompt: str, image_editor_mode: str, init_info: any, init_info_mask:
|
||||
mask_channel[mask_channel >= 255] = 255
|
||||
mask_channel[mask_channel < 255] = 0
|
||||
mask_channel = Image.fromarray(mask_channel).filter(ImageFilter.GaussianBlur(2))
|
||||
elif init_mask is not None:
|
||||
elif image_editor_mode == "Mask":
|
||||
alpha = init_mask.convert("RGBA")
|
||||
alpha = resize_image(resize_mode, alpha, width // 8, height // 8)
|
||||
mask_channel = alpha.split()[1]
|
||||
@ -1505,7 +1332,7 @@ def img2img(prompt: str, image_editor_mode: str, init_info: any, init_info_mask:
|
||||
init_image = init_image.to(device)
|
||||
init_image = repeat(init_image, '1 ... -> b ...', b=batch_size)
|
||||
init_latent = (model if not opt.optimized else modelFS).get_first_stage_encoding((model if not opt.optimized else modelFS).encode_first_stage(init_image)) # move to latent space
|
||||
|
||||
|
||||
if opt.optimized:
|
||||
mem = torch.cuda.memory_allocated()/1e6
|
||||
modelFS.to("cpu")
|
||||
@ -1514,7 +1341,7 @@ def img2img(prompt: str, image_editor_mode: str, init_info: any, init_info_mask:
|
||||
|
||||
return init_latent, mask,
|
||||
|
||||
def sample(init_data, x, conditioning, unconditional_conditioning, sampler_name, img_callback: Callable = None):
|
||||
def sample(init_data, x, conditioning, unconditional_conditioning, sampler_name):
|
||||
t_enc_steps = t_enc
|
||||
obliterate = False
|
||||
if ddim_steps == t_enc_steps:
|
||||
@ -1536,7 +1363,7 @@ def img2img(prompt: str, image_editor_mode: str, init_info: any, init_info_mask:
|
||||
|
||||
sigma_sched = sigmas[ddim_steps - t_enc_steps - 1:]
|
||||
model_wrap_cfg = CFGMaskedDenoiser(sampler.model_wrap)
|
||||
samples_ddim = K.sampling.__dict__[f'sample_{sampler.get_sampler_name()}'](model_wrap_cfg, xi, sigma_sched, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': cfg_scale, 'mask': z_mask, 'x0': x0, 'xi': xi}, disable=False, callback=partial(KDiffusionSampler.img_callback_wrapper, img_callback))
|
||||
samples_ddim = K.sampling.__dict__[f'sample_{sampler.get_sampler_name()}'](model_wrap_cfg, xi, sigma_sched, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': cfg_scale, 'mask': z_mask, 'x0': x0, 'xi': xi}, disable=False)
|
||||
else:
|
||||
|
||||
x0, z_mask = init_data
|
||||
@ -1563,7 +1390,17 @@ def img2img(prompt: str, image_editor_mode: str, init_info: any, init_info_mask:
|
||||
history = []
|
||||
initial_seed = None
|
||||
|
||||
do_color_correction = False
|
||||
try:
|
||||
from skimage import exposure
|
||||
do_color_correction = True
|
||||
except:
|
||||
print("Install scikit-image to perform color correction on loopback")
|
||||
|
||||
for i in range(n_iter):
|
||||
if do_color_correction and i == 0:
|
||||
correction_target = cv2.cvtColor(np.asarray(init_img.copy()), cv2.COLOR_RGB2LAB)
|
||||
|
||||
output_images, seed, info, stats = process_images(
|
||||
outpath=outpath,
|
||||
func_init=init,
|
||||
@ -1571,7 +1408,7 @@ def img2img(prompt: str, image_editor_mode: str, init_info: any, init_info_mask:
|
||||
prompt=prompt,
|
||||
seed=seed,
|
||||
sampler_name=sampler_name,
|
||||
save_each=save_each,
|
||||
skip_save=skip_save,
|
||||
skip_grid=skip_grid,
|
||||
batch_size=1,
|
||||
n_iter=1,
|
||||
@ -1605,6 +1442,17 @@ def img2img(prompt: str, image_editor_mode: str, init_info: any, init_info_mask:
|
||||
initial_seed = seed
|
||||
|
||||
init_img = output_images[0]
|
||||
|
||||
if do_color_correction and correction_target is not None:
|
||||
init_img = Image.fromarray(cv2.cvtColor(exposure.match_histograms(
|
||||
cv2.cvtColor(
|
||||
np.asarray(init_img),
|
||||
cv2.COLOR_RGB2LAB
|
||||
),
|
||||
correction_target,
|
||||
channel_axis=2
|
||||
), cv2.COLOR_LAB2RGB).astype("uint8"))
|
||||
|
||||
if not random_seed_loopback:
|
||||
seed = seed + 1
|
||||
else:
|
||||
@ -1630,7 +1478,7 @@ def img2img(prompt: str, image_editor_mode: str, init_info: any, init_info_mask:
|
||||
prompt=prompt,
|
||||
seed=seed,
|
||||
sampler_name=sampler_name,
|
||||
save_each=save_each,
|
||||
skip_save=skip_save,
|
||||
skip_grid=skip_grid,
|
||||
batch_size=batch_size,
|
||||
n_iter=n_iter,
|
||||
@ -1655,7 +1503,6 @@ def img2img(prompt: str, image_editor_mode: str, init_info: any, init_info_mask:
|
||||
write_info_files=write_info_files,
|
||||
write_sample_info_to_log_file=write_sample_info_to_log_file,
|
||||
jpg_sample=jpg_sample,
|
||||
resize_mask=resize_mask,
|
||||
job_info=job_info
|
||||
)
|
||||
|
||||
@ -1723,10 +1570,9 @@ def imgproc(image,image_batch,imgproc_prompt,imgproc_toggles, imgproc_upscale_to
|
||||
output = []
|
||||
images = []
|
||||
def processGFPGAN(image,strength):
|
||||
cvimage = toImgOpenCV(image)
|
||||
cropped_faces, restored_faces, restored_img = GFPGAN.enhance(np.array(cvimage, dtype=np.uint8), has_aligned=False, only_center_face=False, paste_back=True)
|
||||
#save restored image
|
||||
result = toImgPIL(restored_img)
|
||||
image = image.convert("RGB")
|
||||
cropped_faces, restored_faces, restored_img = GFPGAN.enhance(np.array(image, dtype=np.uint8), has_aligned=False, only_center_face=False, paste_back=True)
|
||||
result = Image.fromarray(restored_img)
|
||||
if strength < 1.0:
|
||||
result = Image.blend(image, result, strength)
|
||||
|
||||
@ -1764,7 +1610,7 @@ def imgproc(image,image_batch,imgproc_prompt,imgproc_toggles, imgproc_upscale_to
|
||||
height = int(imgproc_height)
|
||||
cfg_scale = float(imgproc_cfg)
|
||||
denoising_strength = float(imgproc_denoising)
|
||||
save_each = True
|
||||
skip_save = True
|
||||
skip_grid = True
|
||||
prompt = imgproc_prompt
|
||||
t_enc = int(denoising_strength * ddim_steps)
|
||||
@ -1918,7 +1764,7 @@ def imgproc(image,image_batch,imgproc_prompt,imgproc_toggles, imgproc_upscale_to
|
||||
prompt=prompt,
|
||||
seed=seed,
|
||||
sampler_name=sampler_name,
|
||||
save_each=save_each,
|
||||
skip_save=skip_save,
|
||||
skip_grid=skip_grid,
|
||||
batch_size=batch_size,
|
||||
n_iter=n_iter,
|
||||
@ -1964,9 +1810,8 @@ def imgproc(image,image_batch,imgproc_prompt,imgproc_toggles, imgproc_upscale_to
|
||||
return combined_image
|
||||
def processLDSR(image):
|
||||
result = LDSR.superResolution(image,int(imgproc_ldsr_steps),str(imgproc_ldsr_pre_downSample),str(imgproc_ldsr_post_downSample))
|
||||
return result
|
||||
|
||||
|
||||
return result
|
||||
|
||||
|
||||
if image_batch != None:
|
||||
if image != None:
|
||||
@ -1993,7 +1838,7 @@ def imgproc(image,image_batch,imgproc_prompt,imgproc_toggles, imgproc_upscale_to
|
||||
if 1 in imgproc_toggles:
|
||||
if imgproc_upscale_toggles == 0:
|
||||
ModelLoader(['GFPGAN','LDSR'],False,True) # Unload unused models
|
||||
ModelLoader(['RealESGAN'],True,False,imgproc_realesrgan_model_name) # Load used models
|
||||
ModelLoader(['RealESGAN'],True,False,imgproc_realesrgan_model_name) # Load used models
|
||||
elif imgproc_upscale_toggles == 1:
|
||||
ModelLoader(['GFPGAN','LDSR'],False,True) # Unload unused models
|
||||
ModelLoader(['RealESGAN','model'],True,False) # Load used models
|
||||
@ -2106,14 +1951,15 @@ def ModelLoader(models,load=False,unload=False,imgproc_realesrgan_model_name='Re
|
||||
def run_GFPGAN(image, strength):
|
||||
ModelLoader(['LDSR','RealESRGAN'],False,True)
|
||||
ModelLoader(['GFPGAN'],True,False)
|
||||
cvimage = toImgOpenCV(image)
|
||||
cropped_faces, restored_faces, restored_img = GFPGAN.enhance(np.array(cvimage, dtype=np.uint8), has_aligned=False, only_center_face=False, paste_back=True)
|
||||
#save restored image
|
||||
result = toImgPIL(restored_img)
|
||||
if strength < 1.0:
|
||||
result = Image.blend(image, result, strength)
|
||||
image = image.convert("RGB")
|
||||
|
||||
return result
|
||||
cropped_faces, restored_faces, restored_img = GFPGAN.enhance(np.array(image, dtype=np.uint8), has_aligned=False, only_center_face=False, paste_back=True)
|
||||
res = Image.fromarray(restored_img)
|
||||
|
||||
if strength < 1.0:
|
||||
res = Image.blend(image, res, strength)
|
||||
|
||||
return res
|
||||
|
||||
def run_RealESRGAN(image, model_name: str):
|
||||
ModelLoader(['GFPGAN','LDSR'],False,True)
|
||||
@ -2195,9 +2041,9 @@ imgproc_mode_toggles = [
|
||||
'Upscale'
|
||||
]
|
||||
|
||||
sample_img2img = "assets/stable-samples/img2img/sketch-mountains-input.jpg"
|
||||
sample_img2img = sample_img2img if os.path.exists(sample_img2img) else None
|
||||
|
||||
#sample_img2img = "assets/stable-samples/img2img/sketch-mountains-input.jpg"
|
||||
#sample_img2img = sample_img2img if os.path.exists(sample_img2img) else None
|
||||
sample_img2img = None
|
||||
# make sure these indicies line up at the top of img2img()
|
||||
img2img_toggles = [
|
||||
'Create prompt matrix (separate multiple prompts using |, and get all combinations of them)',
|
||||
@ -2226,7 +2072,6 @@ img2img_resize_modes = [
|
||||
"Just resize",
|
||||
"Crop and resize",
|
||||
"Resize and fill",
|
||||
"Resize Masked Area"
|
||||
]
|
||||
|
||||
img2img_defaults = {
|
||||
@ -2262,22 +2107,13 @@ def update_image_mask(cropped_image, resize_mode, width, height):
|
||||
resized_cropped_image = resize_image(resize_mode, cropped_image, width, height) if cropped_image else None
|
||||
return gr.update(value=resized_cropped_image)
|
||||
|
||||
def copy_img_to_input(img):
|
||||
try:
|
||||
image_data = re.sub('^data:image/.+;base64,', '', img)
|
||||
processed_image = Image.open(BytesIO(base64.b64decode(image_data)))
|
||||
tab_update = gr.update(selected='img2img_tab')
|
||||
img_update = gr.update(value=processed_image)
|
||||
return {img2img_image_mask: processed_image, img2img_image_editor: img_update, tabs: tab_update}
|
||||
except IndexError:
|
||||
return [None, None]
|
||||
|
||||
|
||||
def copy_img_to_upscale_esrgan(img):
|
||||
update = gr.update(selected='realesrgan_tab')
|
||||
image_data = re.sub('^data:image/.+;base64,', '', img)
|
||||
processed_image = Image.open(BytesIO(base64.b64decode(image_data)))
|
||||
return {realesrgan_source: processed_image, tabs: update}
|
||||
return {'realesrgan_source': processed_image, 'tabs': update}
|
||||
|
||||
|
||||
help_text = """
|
||||
@ -2341,7 +2177,7 @@ class ServerLauncher(threading.Thread):
|
||||
'inbrowser': opt.inbrowser,
|
||||
'server_name': '0.0.0.0',
|
||||
'server_port': opt.port,
|
||||
'share': opt.share,
|
||||
'share': opt.share,
|
||||
'show_error': True
|
||||
}
|
||||
if not opt.share:
|
||||
|
Loading…
Reference in New Issue
Block a user