img2img-fix (#717)

This commit is contained in:
hlky 2022-09-07 00:48:13 +01:00 committed by GitHub
parent f28255466b
commit 70d4b1ca2a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 181 additions and 473 deletions

View File

@ -15,14 +15,9 @@ def css(opt):
# TODO: @altryne restore this before merge
if not opt.no_progressbar_hiding:
styling += readTextFile("css", "no_progress_bar.css")
if opt.custom_css:
try:
styling += readTextFile("css", "custom.css")
print("Custom CSS loaded")
except:
pass
return styling
def js(opt):
data = readTextFile("js", "index.js")
data = "(z) => {" + data + "; return z ?? [] }"

View File

@ -57,21 +57,20 @@ def draw_gradio_ui(opt, img2img=lambda x: x, txt2img=lambda x: x,imgproc=lambda
output_txt2img_params = gr.Highlightedtext(label="Generation parameters", interactive=False, elem_id='highlight')
with gr.Group():
with gr.Row(elem_id='txt2img_output_row'):
output_txt2img_copy_params = gr.Button("Copy all").click(
output_txt2img_copy_params = gr.Button("Copy full parameters").click(
inputs=[output_txt2img_params], outputs=[],
_js=js_copy_txt2img_output,
fn=None, show_progress=False)
output_txt2img_seed = gr.Number(label='Seed', interactive=False, visible=False)
output_txt2img_copy_seed = gr.Button("Copy seed").click(
output_txt2img_copy_seed = gr.Button("Copy only seed").click(
inputs=[output_txt2img_seed], outputs=[],
_js='(x) => navigator.clipboard.writeText(x)', fn=None, show_progress=False)
output_txt2img_stats = gr.HTML(label='Stats')
with gr.Column():
with gr.Row():
txt2img_steps = gr.Slider(minimum=1, maximum=250, step=1, label="Sampling Steps",
value=txt2img_defaults['ddim_steps'])
txt2img_sampling = gr.Dropdown(label='Sampling method (k_lms is default k-diffusion sampler)',
txt2img_steps = gr.Slider(minimum=1, maximum=250, step=1, label="Sampling Steps",
value=txt2img_defaults['ddim_steps'])
txt2img_sampling = gr.Dropdown(label='Sampling method (k_lms is default k-diffusion sampler)',
choices=["DDIM", "PLMS", 'k_dpm_2_a', 'k_dpm_2', 'k_euler_a',
'k_euler', 'k_heun', 'k_lms'],
value=txt2img_defaults['sampler_name'])
@ -158,26 +157,20 @@ def draw_gradio_ui(opt, img2img=lambda x: x, txt2img=lambda x: x,imgproc=lambda
img2img_btn_editor = gr.Button("Generate", variant="primary", elem_id="img2img_edit_btn")
with gr.Row().style(equal_height=False):
with gr.Column():
with gr.Tabs():
with gr.TabItem("Img2Img Input"):
#gr.Markdown('#### Img2Img Input')
img2img_image_editor = gr.Image(value=sample_img2img, source="upload", interactive=True,
type="pil", tool="select", elem_id="img2img_editor",
image_mode="RGBA")
img2img_image_mask = gr.Image(value=sample_img2img, source="upload", interactive=True,
type="pil", tool="sketch", visible=False,
elem_id="img2img_mask")
with gr.TabItem("Img2Img Mask Input"):
img2img_mask_input = gr.Image(label="Mask",source="upload", interactive=False,
type="pil", visible=True)
gr.Markdown('#### Img2Img Input')
img2img_image_editor = gr.Image(value=sample_img2img, source="upload", interactive=True,
type="pil", tool="select", elem_id="img2img_editor", image_mode="RGBA"
)
img2img_image_mask = gr.Image(value=sample_img2img, source="upload", interactive=True,
type="pil", tool="sketch", visible=False, image_mode="RGBA",
elem_id="img2img_mask")
with gr.Tabs():
with gr.TabItem("Editor Options"):
with gr.Row():
img2img_image_editor_mode = gr.Radio(choices=["Mask", "Crop", "Uncrop"], label="Image Editor Mode",
value="Crop", elem_id='edit_mode_select')
img2img_mask = gr.Radio(choices=["Keep masked area", "Regenerate only masked area", "Resize and regenerate only masked area"],
img2img_mask = gr.Radio(choices=["Keep masked area", "Regenerate only masked area"],
label="Mask Mode", type="index",
value=img2img_mask_modes[img2img_defaults['mask_mode']], visible=False)
@ -263,16 +256,22 @@ def draw_gradio_ui(opt, img2img=lambda x: x, txt2img=lambda x: x,imgproc=lambda
img2img_image_editor_mode.change(
uifn.change_image_editor_mode,
[img2img_image_editor_mode, img2img_image_editor, img2img_resize, img2img_width, img2img_height],
[img2img_image_editor_mode,
img2img_image_editor,
img2img_image_mask,
img2img_resize,
img2img_width,
img2img_height
],
[img2img_image_editor, img2img_image_mask, img2img_btn_editor, img2img_btn_mask,
img2img_painterro_btn, img2img_mask, img2img_mask_blur_strength, img2img_mask_input]
img2img_painterro_btn, img2img_mask, img2img_mask_blur_strength]
)
img2img_image_editor.edit(
uifn.update_image_mask,
[img2img_image_editor, img2img_resize, img2img_width, img2img_height],
img2img_image_mask
)
# img2img_image_editor_mode.change(
# uifn.update_image_mask,
# [img2img_image_editor, img2img_resize, img2img_width, img2img_height],
# img2img_image_mask
# )
output_txt2img_copy_to_input_btn.click(
uifn.copy_img_to_input,
@ -306,11 +305,11 @@ def draw_gradio_ui(opt, img2img=lambda x: x, txt2img=lambda x: x,imgproc=lambda
)
img2img_func = img2img
img2img_inputs = [img2img_prompt, img2img_image_editor_mode, img2img_image_editor, img2img_image_mask, img2img_mask,
img2img_inputs = [img2img_prompt, img2img_image_editor_mode, img2img_mask,
img2img_mask_blur_strength, img2img_steps, img2img_sampling, img2img_toggles,
img2img_realesrgan_model_name, img2img_batch_count, img2img_cfg,
img2img_denoising, img2img_seed, img2img_height, img2img_width, img2img_resize,
img2img_embeddings, img2img_mask_input]
img2img_image_editor, img2img_image_mask, img2img_embeddings]
img2img_outputs = [output_img2img_gallery, output_img2img_seed, output_img2img_params, output_img2img_stats]
# If a JobManager was passed in then wrap the Generate functions
@ -321,33 +320,23 @@ def draw_gradio_ui(opt, img2img=lambda x: x, txt2img=lambda x: x,imgproc=lambda
outputs=img2img_outputs,
)
def generate(*args):
args_list = list(args)
init_info_mask = args_list[3]
# Get the mask input and remove it from the list
mask_input = args_list[18]
del args_list[18]
# If an external mask is set, use it
if mask_input:
init_info_mask['mask'] = mask_input
args_list[3] = init_info_mask
# Return the result of img2img
return img2img_func(*args_list)
img2img_btn_mask.click(
generate,
img2img_func,
img2img_inputs,
img2img_outputs
)
img2img_btn_editor.click(
img2img_func,
def img2img_submit_params():
#print([img2img_prompt, img2img_image_editor_mode, img2img_mask,
# img2img_mask_blur_strength, img2img_steps, img2img_sampling, img2img_toggles,
# img2img_realesrgan_model_name, img2img_batch_count, img2img_cfg,
# img2img_denoising, img2img_seed, img2img_height, img2img_width, img2img_resize,
# img2img_image_editor, img2img_image_mask, img2img_embeddings])
return (img2img_func,
img2img_inputs,
img2img_outputs)
img2img_btn_editor.click(*img2img_submit_params())
# GENERATE ON ENTER
img2img_prompt.submit(None, None, None,
_js=call_JS("clickFirstVisibleButton",
@ -374,7 +363,7 @@ def draw_gradio_ui(opt, img2img=lambda x: x, txt2img=lambda x: x,imgproc=lambda
# value=gfpgan_defaults['strength'])
#select folder with images to process
with gr.TabItem('Batch Process'):
imgproc_folder = gr.File(label="Batch Process", file_count="multiple", interactive=True, type="file")
imgproc_folder = gr.File(label="Batch Process", file_count="multiple",source="upload", interactive=True, type="file")
imgproc_pngnfo = gr.Textbox(label="PNG Metadata", placeholder="PngNfo", visible=False, max_lines=5)
with gr.Row():
imgproc_btn = gr.Button("Process", variant="primary")
@ -580,7 +569,7 @@ def draw_gradio_ui(opt, img2img=lambda x: x, txt2img=lambda x: x,imgproc=lambda
<div id="90" style="max-width: 100%; font-size: 14px; text-align: center;" class="output-markdown gr-prose border-solid border border-gray-200 rounded gr-panel">
<p>For help and advanced usage guides, visit the <a href="https://github.com/hlky/stable-diffusion-webui/wiki" target="_blank">Project Wiki</a></p>
<p>Stable Diffusion WebUI is an open-source project. You can find the latest stable builds on the <a href="https://github.com/hlky/stable-diffusion" target="_blank">main repository</a>.
If you would like to contribute to development or test bleeding edge builds, you can visit the <a href="https://github.com/hlky/stable-diffusion-webui" target="_blank">development repository</a>.</p>
If you would like to contribute to development or test bleeding edge builds, you can visit the <a href="https://github.com/hlky/stable-diffusion-webui" target="_blank">developement repository</a>.</p>
</div>
""")
# Hack: Detect the load event on the frontend

View File

@ -1,7 +1,7 @@
''' Provides simple job management for gradio, allowing viewing and stopping in-progress multi-batch generations '''
from __future__ import annotations
import gradio as gr
from gradio.components import Component, Gallery, Slider
from gradio.components import Component, Gallery
from threading import Event, Timer
from typing import Callable, List, Dict, Tuple, Optional, Any
from dataclasses import dataclass, field
@ -30,17 +30,7 @@ class JobInfo:
session_key: str
job_token: Optional[int] = None
images: List[Image] = field(default_factory=list)
active_image: Image = None
rec_steps_enabled: bool = False
rec_steps_imgs: List[Image] = field(default_factory=list)
rec_steps_intrvl: int = None
rec_steps_to_gallery: bool = False
rec_steps_to_file: bool = False
should_stop: Event = field(default_factory=Event)
refresh_active_image_requested: Event = field(default_factory=Event)
refresh_active_image_done: Event = field(default_factory=Event)
stop_cur_iter: Event = field(default_factory=Event)
active_iteration_cnt: int = field(default_factory=int)
job_status: str = field(default_factory=str)
finished: bool = False
removed_output_idxs: List[int] = field(default_factory=list)
@ -86,7 +76,7 @@ class JobManagerUi:
'''
return self._job_manager._wrap_func(
func=func, inputs=inputs, outputs=outputs,
job_ui=self
refresh_btn=self._refresh_btn, stop_btn=self._stop_btn, status_text=self._status_text
)
_refresh_btn: gr.Button
@ -94,13 +84,6 @@ class JobManagerUi:
_status_text: gr.Textbox
_stop_all_session_btn: gr.Button
_free_done_sessions_btn: gr.Button
_active_image: gr.Image
_active_image_stop_btn: gr.Button
_active_image_refresh_btn: gr.Button
_rec_steps_intrvl_sldr: gr.Slider
_rec_steps_checkbox: gr.Checkbox
_save_rec_steps_to_gallery_chkbx: gr.Checkbox
_save_rec_steps_to_file_chkbx: gr.Checkbox
_job_manager: JobManager
@ -119,23 +102,11 @@ class JobManager:
'''
assert gr.context.Context.block is not None, "draw_gradio_ui must be called within a 'gr.Blocks' 'with' context"
with gr.Tabs():
with gr.TabItem("Job Controls"):
with gr.TabItem("Current Session"):
with gr.Row():
stop_btn = gr.Button("Stop All Batches", elem_id="stop", variant="secondary")
refresh_btn = gr.Button("Refresh Finished Batches", elem_id="refresh", variant="secondary")
stop_btn = gr.Button("Stop", elem_id="stop", variant="secondary")
refresh_btn = gr.Button("Refresh", elem_id="refresh", variant="secondary")
status_text = gr.Textbox(placeholder="Job Status", interactive=False, show_label=False)
with gr.Row():
active_image_stop_btn = gr.Button("Skip Active Batch", variant="secondary")
active_image_refresh_btn = gr.Button("View Batch Progress", variant="secondary")
active_image = gr.Image(type="pil", interactive=False, visible=False, elem_id="active_iteration_image")
with gr.TabItem("Batch Progress Settings"):
with gr.Row():
record_steps_checkbox = gr.Checkbox(value=False, label="Enable Batch Progress Grid")
record_steps_interval_slider = gr.Slider(
value=3, label="Record Interval (steps)", minimum=1, maximum=25, step=1)
with gr.Row() as record_steps_box:
steps_to_gallery_checkbox = gr.Checkbox(value=False, label="Save Progress Grid to Gallery")
steps_to_file_checkbox = gr.Checkbox(value=False, label="Save Progress Grid to File")
with gr.TabItem("Maintenance"):
with gr.Row():
gr.Markdown(
@ -147,15 +118,9 @@ class JobManager:
free_done_sessions_btn = gr.Button(
"Clear Finished Jobs", elem_id="clear_finished", variant="secondary"
)
return JobManagerUi(_refresh_btn=refresh_btn, _stop_btn=stop_btn, _status_text=status_text,
_stop_all_session_btn=stop_all_sessions_btn, _free_done_sessions_btn=free_done_sessions_btn,
_active_image=active_image, _active_image_stop_btn=active_image_stop_btn,
_active_image_refresh_btn=active_image_refresh_btn,
_rec_steps_checkbox=record_steps_checkbox,
_save_rec_steps_to_gallery_chkbx=steps_to_gallery_checkbox,
_save_rec_steps_to_file_chkbx=steps_to_file_checkbox,
_rec_steps_intrvl_sldr=record_steps_interval_slider, _job_manager=self)
_job_manager=self)
def clear_all_finished_jobs(self):
''' Removes all currently finished jobs, across all sessions.
@ -169,7 +134,6 @@ class JobManager:
for session in self._sessions.values():
for job in session.jobs.values():
job.should_stop.set()
job.stop_cur_iter.set()
def _get_job_token(self, block: bool = False) -> Optional[int]:
''' Attempts to acquire a job token, optionally blocking until available '''
@ -211,26 +175,6 @@ class JobManager:
job_info.should_stop.set()
return "Stopping after current batch finishes"
def _refresh_cur_iter_func(self, func_key: FuncKey, session_key: str) -> List[Component]:
''' Updates information from the active iteration '''
session_info, job_info = self._get_call_info(func_key, session_key)
if job_info is None:
return [None, f"Session {session_key} was not running function {func_key}"]
job_info.refresh_active_image_requested.set()
if job_info.refresh_active_image_done.wait(timeout=20.0):
job_info.refresh_active_image_done.clear()
return [gr.Image.update(value=job_info.active_image, visible=True), f"Sample iteration {job_info.active_iteration_cnt}"]
return [gr.Image.update(visible=False), "Timed out getting image"]
def _stop_cur_iter_func(self, func_key: FuncKey, session_key: str) -> List[Component]:
''' Marks that the active iteration should be stopped'''
session_info, job_info = self._get_call_info(func_key, session_key)
if job_info is None:
return [None, f"Session {session_key} was not running function {func_key}"]
job_info.stop_cur_iter.set()
return [gr.Image.update(visible=False), "Stopping current iteration"]
def _get_call_info(self, func_key: FuncKey, session_key: str) -> Tuple[SessionInfo, JobInfo]:
''' Helper to get the SessionInfo and JobInfo. '''
session_info = self._sessions.get(session_key, None)
@ -263,8 +207,7 @@ class JobManager:
def _pre_call_func(
self, func_key: FuncKey, output_dummy_obj: Component, refresh_btn: gr.Button, stop_btn: gr.Button,
status_text: gr.Textbox, active_image: gr.Image, active_refresh_btn: gr.Button, active_stop_btn: gr.Button,
session_key: str) -> List[Component]:
status_text: gr.Textbox, session_key: str) -> List[Component]:
''' Called when a job is about to start '''
session_info, job_info = self._get_call_info(func_key, session_key)
@ -276,9 +219,7 @@ class JobManager:
return {output_dummy_obj: triggerChangeEvent(),
refresh_btn: gr.Button.update(variant="primary", value=refresh_btn.value),
stop_btn: gr.Button.update(variant="primary", value=stop_btn.value),
status_text: gr.Textbox.update(value="Generation has started. Click 'Refresh' to see finished images, 'View Batch Progress' for active images"),
active_refresh_btn: gr.Button.update(variant="primary", value=active_refresh_btn.value),
active_stop_btn: gr.Button.update(variant="primary", value=active_stop_btn.value),
status_text: gr.Textbox.update(value="Generation has started. Click 'Refresh' for updates")
}
def _call_func(self, func_key: FuncKey, session_key: str) -> List[Component]:
@ -292,7 +233,7 @@ class JobManager:
except Exception as e:
job_info.job_status = f"Error: {e}"
print(f"Exception processing job {job_info}: {e}\n{traceback.format_exc()}")
raise
outputs = []
# Filter the function output for any removed outputs
filtered_output = []
@ -313,16 +254,12 @@ class JobManager:
def _post_call_func(
self, func_key: FuncKey, output_dummy_obj: Component, refresh_btn: gr.Button, stop_btn: gr.Button,
status_text: gr.Textbox, active_image: gr.Image, active_refresh_btn: gr.Button, active_stop_btn: gr.Button,
session_key: str) -> List[Component]:
status_text: gr.Textbox, session_key: str) -> List[Component]:
''' Called when a job completes '''
return {output_dummy_obj: triggerChangeEvent(),
refresh_btn: gr.Button.update(variant="secondary", value=refresh_btn.value),
stop_btn: gr.Button.update(variant="secondary", value=stop_btn.value),
status_text: gr.Textbox.update(value="Generation has finished!"),
active_refresh_btn: gr.Button.update(variant="secondary", value=active_refresh_btn.value),
active_stop_btn: gr.Button.update(variant="secondary", value=active_stop_btn.value),
active_image: gr.Image.update(visible=False)
status_text: gr.Textbox.update(value="Generation has finished!")
}
def _update_gallery_event(self, func_key: FuncKey, session_key: str) -> List[Component]:
@ -338,15 +275,16 @@ class JobManager:
return job_info.images
def _wrap_func(self, func: Callable, inputs: List[Component],
outputs: List[Component],
job_ui: JobManagerUi) -> Tuple[Callable, List[Component]]:
def _wrap_func(
self, func: Callable, inputs: List[Component], outputs: List[Component],
refresh_btn: gr.Button = None, stop_btn: gr.Button = None,
status_text: Optional[gr.Textbox] = None) -> Tuple[Callable, List[Component]]:
''' handles JobManageUI's wrap_func'''
assert gr.context.Context.block is not None, "wrap_func must be called within a 'gr.Blocks' 'with' context"
# Create a unique key for this job
func_key = FuncKey(job_id=uuid.uuid4().hex, func=func)
func_key = FuncKey(job_id=uuid.uuid4(), func=func)
# Create a unique session key (next gradio release can use gr.State, see https://gradio.app/state_in_blocks/)
if self._session_key is None:
@ -364,6 +302,9 @@ class JobManager:
del outputs[idx]
break
# Add the session key to the inputs
inputs += [self._session_key]
# Create dummy objects
update_gallery_obj = gr.JSON(visible=False, elem_id="JobManagerDummyObject")
update_gallery_obj.change(
@ -372,44 +313,20 @@ class JobManager:
[gallery_comp]
)
if job_ui._refresh_btn:
job_ui._refresh_btn.variant = 'secondary'
job_ui._refresh_btn.click(
if refresh_btn:
refresh_btn.variant = 'secondary'
refresh_btn.click(
partial(self._refresh_func, func_key),
[self._session_key],
[update_gallery_obj, job_ui._status_text]
[update_gallery_obj, status_text]
)
if job_ui._stop_btn:
job_ui._stop_btn.variant = 'secondary'
job_ui._stop_btn.click(
if stop_btn:
stop_btn.variant = 'secondary'
stop_btn.click(
partial(self._stop_wrapped_func, func_key),
[self._session_key],
[job_ui._status_text]
)
if job_ui._active_image and job_ui._active_image_refresh_btn:
job_ui._active_image_refresh_btn.click(
partial(self._refresh_cur_iter_func, func_key),
[self._session_key],
[job_ui._active_image, job_ui._status_text]
)
if job_ui._active_image_stop_btn:
job_ui._active_image_stop_btn.click(
partial(self._stop_cur_iter_func, func_key),
[self._session_key],
[job_ui._active_image, job_ui._status_text]
)
if job_ui._stop_all_session_btn:
job_ui._stop_all_session_btn.click(
self.stop_all_jobs, [], []
)
if job_ui._free_done_sessions_btn:
job_ui._free_done_sessions_btn.click(
self.clear_all_finished_jobs, [], []
[status_text]
)
# (ab)use gr.JSON to forward events.
@ -426,8 +343,7 @@ class JobManager:
# Since some parameters are optional it makes sense to use the 'dict' return value type, which requires
# the Component as a key... so group together the UI components that the event listeners are going to update
# to make it easy to append to function calls and outputs
job_ui_params = [job_ui._refresh_btn, job_ui._stop_btn, job_ui._status_text,
job_ui._active_image, job_ui._active_image_refresh_btn, job_ui._active_image_stop_btn]
job_ui_params = [refresh_btn, stop_btn, status_text]
job_ui_outputs = [comp for comp in job_ui_params if comp is not None]
# Here a chain is constructed that will make a 'pre' call, a 'run' call, and a 'post' call,
@ -453,39 +369,27 @@ class JobManager:
[call_dummyobj] + job_ui_outputs
)
# Add any components that we want the runtime values for
added_inputs = [self._session_key, job_ui._rec_steps_checkbox, job_ui._save_rec_steps_to_gallery_chkbx,
job_ui._save_rec_steps_to_file_chkbx, job_ui._rec_steps_intrvl_sldr]
# Now replace the original function with one that creates a JobInfo and triggers the dummy obj
def wrapped_func(*wrapped_inputs):
# Remove the added_inputs (pop opposite order of list)
wrapped_inputs = list(wrapped_inputs)
rec_steps_interval: int = wrapped_inputs.pop()
save_rec_steps_file: bool = wrapped_inputs.pop()
save_rec_steps_grid: bool = wrapped_inputs.pop()
record_steps_enabled: bool = wrapped_inputs.pop()
session_key: str = wrapped_inputs.pop()
job_inputs = tuple(wrapped_inputs)
def wrapped_func(*inputs):
session_key = inputs[-1]
inputs = inputs[:-1]
# Get or create a session for this key
session_info = self._sessions.setdefault(session_key, SessionInfo())
# Is this session already running this job?
if func_key in session_info.jobs:
return {job_ui._status_text: "This session is already running that function!"}
return {status_text: "This session is already running that function!"}
job_token = self._get_job_token(block=False)
job = JobInfo(
inputs=job_inputs, func=func, removed_output_idxs=removed_idxs, session_key=session_key,
job_token=job_token, rec_steps_enabled=record_steps_enabled, rec_steps_intrvl=rec_steps_interval,
rec_steps_to_gallery=save_rec_steps_grid, rec_steps_to_file=save_rec_steps_file)
job = JobInfo(inputs=inputs, func=func, removed_output_idxs=removed_idxs, session_key=session_key,
job_token=job_token)
session_info.jobs[func_key] = job
ret = {pre_call_dummyobj: triggerChangeEvent()}
if job_token is None:
ret[job_ui._status_text] = "Job is queued"
ret[status_text] = "Job is queued"
return ret
return wrapped_func, inputs + added_inputs, [pre_call_dummyobj, job_ui._status_text]
return wrapped_func, inputs, [pre_call_dummyobj, status_text]

View File

@ -6,33 +6,17 @@ import base64
import re
def change_image_editor_mode(choice, cropped_image, resize_mode, width, height):
def change_image_editor_mode(choice, cropped_image, masked_image, resize_mode, width, height):
if choice == "Mask":
return [gr.Image.update(visible=False),
gr.Image.update(visible=True),
gr.Button.update("Generate", variant="primary", visible=False),
gr.Button.update("Generate", variant="primary", visible=True),
gr.Button.update("Advanced Editor", visible=False),
gr.Radio.update(choices=["Keep masked area", "Regenerate only masked area"],
label="Mask Mode",
value="Regenerate only masked area", visible=True),
gr.Slider.update(minimum=1, maximum=10, step=1, label="How much blurry should the mask be? (to avoid hard edges)", value=3, visible=True),
gr.Image.update(interactive=True)]
else:
return [gr.Image.update(visible=True),
gr.Image.update(visible=False),
gr.Button.update("Generate", variant="primary", visible=True),
gr.Button.update("Generate", variant="primary", visible=False),
gr.Button.update("Advanced Editor", visible=True),
gr.Radio.update(choices=["Keep masked area", "Regenerate only masked area"],
label="Mask Mode",
value="Regenerate only masked area", visible=False),
gr.Slider.update(minimum=1, maximum=10, step=1, label="How much blurry should the mask be? (to avoid hard edges)", value=3, visible=False),
gr.Image.update(interactive=False)]
update_image_result = update_image_mask(cropped_image, resize_mode, width, height)
return [gr.update(visible=False), update_image_result, gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), gr.update(visible=True), gr.update(visible=True)]
update_image_result = update_image_mask(masked_image["image"], resize_mode, width, height)
return [update_image_result, gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), gr.update(visible=False)]
def update_image_mask(cropped_image, resize_mode, width, height):
resized_cropped_image = resize_image(resize_mode, cropped_image, width, height) if cropped_image else None
return gr.Image.update(value=resized_cropped_image)
return gr.update(value=resized_cropped_image, visible=True)
def toggle_options_gfpgan(selection):
if 0 in selection:

View File

@ -1,5 +1,7 @@
import argparse, os, sys, glob, re
import cv2
from frontend.frontend import draw_gradio_ui
from frontend.job_manager import JobManager, JobInfo
from frontend.ui_functions import resize_image
@ -37,11 +39,9 @@ parser.add_argument("--save-metadata", action='store_true', help="Store generati
parser.add_argument("--share-password", type=str, help="Sharing is open by default, use this to set a password. Username: webui", default=None)
parser.add_argument("--share", action='store_true', help="Should share your server on gradio.app, this allows you to use the UI from your mobile app", default=False)
parser.add_argument("--skip-grid", action='store_true', help="do not save a grid, only individual samples. Helpful when evaluating lots of samples", default=False)
parser.add_argument("--save-each", action='store_true', help="save individual samples. For speed measurements.", default=False)
parser.add_argument("--skip-save", action='store_true', help="do not save indiviual samples. For speed measurements.", default=False)
parser.add_argument('--no-job-manager', action='store_true', help="Don't use the experimental job manager on top of gradio", default=False)
parser.add_argument("--max-jobs", type=int, help="Maximum number of concurrent 'generate' commands", default=1)
parser.add_argument("--custom-css", action='store_true', help="Place custom.css in css folder to load a custom theme of the UI", default=False)
opt = parser.parse_args()
#Should not be needed anymore
@ -66,12 +66,9 @@ import torch
import torch.nn as nn
import yaml
import glob
import copy
from typing import List, Union, Dict, Callable, Any
from typing import List, Union, Dict
from pathlib import Path
from collections import namedtuple
import cv2
from functools import partial
from contextlib import contextmanager, nullcontext
from einops import rearrange, repeat
@ -109,7 +106,6 @@ invalid_filename_chars = '<>:"/\|?*\n'
GFPGAN_dir = opt.gfpgan_dir
RealESRGAN_dir = opt.realesrgan_dir
LDSR_dir = opt.ldsr_dir
returned_info = {}
if opt.optimized_turbo:
opt.optimized = True
@ -140,13 +136,6 @@ elif grid_format[0] == 'webp':
grid_quality = abs(grid_quality)
def toImgOpenCV(imgPIL): # Conver imgPIL to imgOpenCV
i = np.array(imgPIL) # After mapping from PIL to numpy : [R,G,B,A]
# numpy Image Channel system: [B,G,R,A]
red = i[:,:,0].copy(); i[:,:,0] = i[:,:,2].copy(); i[:,:,2] = red
return i
def toImgPIL(imgOpenCV): return Image.fromarray(cv2.cvtColor(imgOpenCV, cv2.COLOR_BGR2RGB))
def chunk(it, size):
it = iter(it)
return iter(lambda: tuple(islice(it, size)), ())
@ -275,21 +264,15 @@ class KDiffusionSampler:
self.schedule = sampler
def get_sampler_name(self):
return self.schedule
def sample(self, S, conditioning, batch_size, shape, verbose, unconditional_guidance_scale, unconditional_conditioning, eta, x_T, img_callback: Callable = None ):
def sample(self, S, conditioning, batch_size, shape, verbose, unconditional_guidance_scale, unconditional_conditioning, eta, x_T):
sigmas = self.model_wrap.get_sigmas(S)
x = x_T * sigmas[0]
model_wrap_cfg = CFGDenoiser(self.model_wrap)
samples_ddim = K.sampling.__dict__[f'sample_{self.schedule}'](model_wrap_cfg, x, sigmas, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': unconditional_guidance_scale}, disable=False, callback=partial(KDiffusionSampler.img_callback_wrapper, img_callback))
samples_ddim = K.sampling.__dict__[f'sample_{self.schedule}'](model_wrap_cfg, x, sigmas, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': unconditional_guidance_scale}, disable=False)
return samples_ddim, None
@classmethod
def img_callback_wrapper(cls, callback: Callable, *args):
''' Converts a KDiffusion callback to the standard img_callback '''
if callback:
arg_dict = args[0]
callback(image_sample=arg_dict['denoised'], iter_num=arg_dict['i'])
def create_random_tensors(shape, seeds):
xs = []
@ -524,7 +507,6 @@ def seed_to_int(s):
n = n >> 32
return n
def draw_prompt_matrix(im, width, height, all_prompts):
def wrap(text, d, font, line_length):
lines = ['']
@ -590,63 +572,6 @@ def draw_prompt_matrix(im, width, height, all_prompts):
return result
def round_to_multiple(dimension, dimension_ceiling, multiple=64, round_down=True):
if round_down:
rounded_dimension = multiple * math.ceil(dimension / multiple)
else:
rounded_dimension = multiple * math.floor(dimension / multiple)
return rounded_dimension
def crop_image(img, mask, width, height):
def get_mask_and_img(img, mask,dimension, coords, target_width, target_height):
longest_target_dimension = round_to_multiple(dimension, dimension)
func_crop_coords = (coords[0], coords[1], coords[0]+longest_target_dimension, coords[1]+longest_target_dimension)
resized_img = img.crop(func_crop_coords)
scale_dimension = target_width if target_width > target_height else target_height
resized_img = resized_img.resize((scale_dimension, scale_dimension), resample=Image.Resampling.LANCZOS)
resized_mask = mask.crop(func_crop_coords)
cropped_img_width, cropped_img_height = resized_mask.size
resized_mask = resized_mask.resize((scale_dimension, scale_dimension), resample=Image.Resampling.LANCZOS)
alpha_mask = resized_mask.convert("RGBA")
mask_data = alpha_mask.getdata()
container = []
for item in mask_data:
if item[0] == 0 and item[1] == 0 and item[2] == 0:
container.append((255, 255, 255, 0))
else:
container.append(item)
alpha_mask.putdata(container)
results = {
"cropped_img": resized_img,
"org_img": rgb_image,
"cropped_mask": alpha_mask,
"coords": crop_coords,
"scale_width": width,
"scale_height": height,
"org_width": cropped_img_width,
"org_height": cropped_img_height
}
return results
rgb_image = img.convert("RGB")
rgb_mask = mask.convert("RGB")
np_mask = np.array(rgb_mask)
white_columns = np.where(np_mask.max(axis=0)>= 255)[0]
white_rows = np.where(np_mask.max(axis=1)>= 255)[0]
crop_coords = (min(white_columns), min(white_rows), max(white_columns), max(white_rows))
crop_to_size = rgb_image.crop(crop_coords)
cropped_img_width, cropped_img_height = crop_to_size.size
if cropped_img_width > cropped_img_height:
results_dict = get_mask_and_img(rgb_image, mask, cropped_img_width, crop_coords, width, height)
else:
results_dict = get_mask_and_img(rgb_image, mask, cropped_img_height, crop_coords, width, height)
return results_dict
def check_prompt_length(prompt, comments):
@ -668,8 +593,8 @@ def check_prompt_length(prompt, comments):
comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")
def save_sample(image, sample_path_i, filename, jpg_sample, prompts, seeds, width, height, steps, cfg_scale,
normalize_prompt_weights, use_GFPGAN, write_info_files, write_sample_info_to_log_file, prompt_matrix, init_img, uses_loopback, uses_random_seed_loopback, save_each,
skip_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoising_strength, resize_mode, skip_metadata=True):
normalize_prompt_weights, use_GFPGAN, write_info_files, write_sample_info_to_log_file, prompt_matrix, init_img, uses_loopback, uses_random_seed_loopback, skip_save,
skip_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoising_strength, resize_mode, skip_metadata=True):
filename_i = os.path.join(sample_path_i, filename)
if not jpg_sample:
if opt.save_metadata and not skip_metadata:
@ -702,7 +627,7 @@ def save_sample(image, sample_path_i, filename, jpg_sample, prompts, seeds, widt
toggles.append(2)
if uses_random_seed_loopback:
toggles.append(3)
if save_each:
if not skip_save:
toggles.append(2 + offset)
if not skip_grid:
toggles.append(3 + offset)
@ -852,12 +777,12 @@ def oxlamon_matrix(prompt, seed, n_iter, batch_size):
def process_images(
outpath, func_init, func_sample, prompt, seed, sampler_name, skip_grid, save_each, batch_size,
outpath, func_init, func_sample, prompt, seed, sampler_name, skip_grid, skip_save, batch_size,
n_iter, steps, cfg_scale, width, height, prompt_matrix, use_GFPGAN, use_RealESRGAN, realesrgan_model_name,
fp, ddim_eta=0.0, do_not_save_grid=False, normalize_prompt_weights=True, init_img=None, init_mask=None,
keep_mask=False, mask_blur_strength=3, denoising_strength=0.75, resize_mode=None, uses_loopback=False,
uses_random_seed_loopback=False, sort_samples=True, write_info_files=True, write_sample_info_to_log_file=False, jpg_sample=False,
variant_amount=0.0, variant_seed=None,imgProcessorTask=False,resize_mask=False, job_info: JobInfo = None):
variant_amount=0.0, variant_seed=None,imgProcessorTask=False, job_info: JobInfo = None):
"""this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch"""
prompt = prompt or ''
torch_gc()
@ -956,7 +881,6 @@ def process_images(
if job_info:
job_info.job_status = f"Processing Iteration {n+1}/{n_iter}. Batch size {batch_size}"
job_info.rec_steps_imgs.clear()
for idx,(p,s) in enumerate(zip(prompts,seeds)):
job_info.job_status += f"\nItem {idx}: Seed {s}\nPrompt: {p}"
@ -1010,78 +934,17 @@ def process_images(
# finally, slerp base_x noise to target_x noise for creating a variant
x = slerp(device, max(0.0, min(1.0, cur_variant_amount)), base_x, target_x)
# If in optimized mode then make a CPU-copy of the model to generate preview images
if opt.optimized:
step_preview_model = copy.deepcopy(modelFS).to("cpu")
if not opt.no_half:
step_preview_model.float()
else:
step_preview_model = model
def sample_iteration_callback(image_sample: torch.Tensor, iter_num: int):
''' Called from the sampler every iteration '''
if job_info:
job_info.active_iteration_cnt = iter_num
record_periodic_image = job_info.rec_steps_enabled and (0 == iter_num % job_info.rec_steps_intrvl)
if record_periodic_image or job_info.refresh_active_image_requested.is_set():
preview_start_time = time.time()
if opt.optimized:
image_sample = image_sample.to("cpu")
batch_ddim = step_preview_model.decode_first_stage(image_sample)
batch_ddim = torch.clamp((batch_ddim + 1.0) / 2.0, min=0.0, max=1.0)
preview_elapsed_timed = time.time() - preview_start_time
if preview_elapsed_timed > 1:
print(
f"Warning: Preview generation is slow! It took {preview_elapsed_timed:.2f}s to generate one preview!")
images: List[Image.Image] = []
# Convert tensor to image (copied from code below)
for ddim in batch_ddim:
x_sample = 255. * rearrange(ddim.cpu().numpy(), 'c h w -> h w c')
x_sample = x_sample.astype(np.uint8)
image = Image.fromarray(x_sample)
images.append(image)
caption = f"Iter {iter_num}"
grid = image_grid(images, len(images), force_n_rows=1, captions=[caption]*len(images))
# Save the images if recording steps, and append existing saved steps
if job_info.rec_steps_enabled:
gallery_img_size = tuple( int(0.25*dim) for dim in images[0].size)
job_info.rec_steps_imgs.append(grid.resize(gallery_img_size))
# Notify the requester that the image is updated
if job_info.refresh_active_image_requested.is_set():
if job_info.rec_steps_enabled:
grid = image_grid(job_info.rec_steps_imgs, 1)
job_info.active_image = grid
job_info.refresh_active_image_done.set()
job_info.refresh_active_image_requested.clear()
# Interrupt current iteration?
if job_info.stop_cur_iter.is_set():
job_info.stop_cur_iter.clear()
raise StopIteration()
try:
samples_ddim = func_sample(init_data=init_data, x=x, conditioning=c, unconditional_conditioning=uc, sampler_name=sampler_name, img_callback=sample_iteration_callback)
except StopIteration:
print("Skipping iteration")
job_info.job_status = "Skipping iteration"
continue
samples_ddim = func_sample(init_data=init_data, x=x, conditioning=c, unconditional_conditioning=uc, sampler_name=sampler_name)
if opt.optimized:
modelFS.to(device)
x_samples_ddim = (model if not opt.optimized else modelFS).decode_first_stage(samples_ddim)
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
for i, x_sample in enumerate(x_samples_ddim):
sanitized_prompt = prompts[i].replace(' ', '_').translate({ord(x): '' for x in invalid_filename_chars})
sanitized_prompt = sanitized_prompt.lower()
if variant_seed != None and variant_seed != '':
if variant_amount == 0.0:
seed_used = f"{current_seeds[i]}-{variant_seed}"
@ -1106,17 +969,6 @@ def process_images(
image = Image.fromarray(x_sample)
original_sample = x_sample
original_filename = filename
if resize_mask:
scaled_img = image.resize((returned_info["org_width"], returned_info["org_height"]), resample=Image.Resampling.LANCZOS).convert("RGB")
scaled_mask = returned_info["cropped_mask"].resize((returned_info["org_width"], returned_info["org_height"]), resample=Image.Resampling.LANCZOS).convert("RGBA")
scaled_mask = scaled_mask.filter(ImageFilter.GaussianBlur(mask_blur_strength))
returned_info["org_img"].paste(scaled_img, (returned_info["coords"][0], returned_info["coords"][1]), mask=scaled_mask)
image = returned_info["org_img"].copy()
original_sample = np.asarray(image).astype(np.uint8)
#returned_info["org_img"].save(sample_path_i+"\\"+filename+" test.png", format="PNG")
if use_GFPGAN and GFPGAN is not None and not use_RealESRGAN:
skip_save = True # #287 >_>
torch_gc()
@ -1124,12 +976,10 @@ def process_images(
gfpgan_sample = restored_img[:,:,::-1]
gfpgan_image = Image.fromarray(gfpgan_sample)
gfpgan_filename = original_filename + '-gfpgan'
if save_each:
save_sample(gfpgan_image, sample_path_i, gfpgan_filename, jpg_sample, prompts, seeds, width, height, steps, cfg_scale,
normalize_prompt_weights, use_GFPGAN, write_info_files, write_sample_info_to_log_file, prompt_matrix, init_img, uses_loopback, uses_random_seed_loopback, save_each,
skip_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoising_strength, resize_mode, skip_metadata=True)
save_sample(gfpgan_image, sample_path_i, gfpgan_filename, jpg_sample, prompts, seeds, width, height, steps, cfg_scale,
normalize_prompt_weights, use_GFPGAN, write_info_files, write_sample_info_to_log_file, prompt_matrix, init_img, uses_loopback, uses_random_seed_loopback, skip_save,
skip_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoising_strength, resize_mode, skip_metadata=True)
output_images.append(gfpgan_image) #287
# save_each = True # #287 >_>
#if simple_templating:
# grid_captions.append( captions[i] + "\ngfpgan" )
@ -1140,30 +990,26 @@ def process_images(
esrgan_filename = original_filename + '-esrgan4x'
esrgan_sample = output[:,:,::-1]
esrgan_image = Image.fromarray(esrgan_sample)
if save_each:
save_sample(esrgan_image, sample_path_i, esrgan_filename, jpg_sample, prompts, seeds, width, height, steps, cfg_scale,
normalize_prompt_weights, use_GFPGAN, write_info_files, write_sample_info_to_log_file, prompt_matrix, init_img, uses_loopback, uses_random_seed_loopback, save_each,
skip_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoising_strength, resize_mode, skip_metadata=True)
save_sample(esrgan_image, sample_path_i, esrgan_filename, jpg_sample, prompts, seeds, width, height, steps, cfg_scale,
normalize_prompt_weights, use_GFPGAN,write_info_files, write_sample_info_to_log_file, prompt_matrix, init_img, uses_loopback, uses_random_seed_loopback, skip_save,
skip_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoising_strength, resize_mode, skip_metadata=True)
output_images.append(esrgan_image) #287
# save_each = False # #287 >_>
#if simple_templating:
# grid_captions.append( captions[i] + "\nesrgan" )
if use_RealESRGAN and RealESRGAN is not None and use_GFPGAN and GFPGAN is not None:
skip_save = True # #287 >_>
torch_gc()
cropped_faces, restored_faces, restored_img = GFPGAN.enhance(original_sample[:,:,::-1], has_aligned=False, only_center_face=False, paste_back=True)
cropped_faces, restored_faces, restored_img = GFPGAN.enhance(x_sample[:,:,::-1], has_aligned=False, only_center_face=False, paste_back=True)
gfpgan_sample = restored_img[:,:,::-1]
output, img_mode = RealESRGAN.enhance(gfpgan_sample[:,:,::-1])
gfpgan_esrgan_filename = original_filename + '-gfpgan-esrgan4x'
gfpgan_esrgan_sample = output[:,:,::-1]
gfpgan_esrgan_image = Image.fromarray(gfpgan_esrgan_sample)
if save_each:
save_sample(gfpgan_esrgan_image, sample_path_i, gfpgan_esrgan_filename, jpg_sample, prompts, seeds, width, height, steps, cfg_scale,
normalize_prompt_weights, use_GFPGAN, write_info_files, write_sample_info_to_log_file, prompt_matrix, init_img, uses_loopback, uses_random_seed_loopback, save_each,
skip_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoising_strength, resize_mode, skip_metadata=True)
save_sample(gfpgan_esrgan_image, sample_path_i, gfpgan_esrgan_filename, jpg_sample, prompts, seeds, width, height, steps, cfg_scale,
normalize_prompt_weights, use_GFPGAN, write_info_files, write_sample_info_to_log_file, prompt_matrix, init_img, uses_loopback, uses_random_seed_loopback, skip_save,
skip_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoising_strength, resize_mode, skip_metadata=True)
output_images.append(gfpgan_esrgan_image) #287
# save_each = False # #287 >_>
#if simple_templating:
# grid_captions.append( captions[i] + "\ngfpgan_esrgan" )
@ -1171,30 +1017,15 @@ def process_images(
if imgProcessorTask == True:
output_images.append(image)
if save_each:
if not skip_save:
save_sample(image, sample_path_i, filename, jpg_sample, prompts, seeds, width, height, steps, cfg_scale,
normalize_prompt_weights, use_GFPGAN, write_info_files, write_sample_info_to_log_file, prompt_matrix, init_img, uses_loopback, uses_random_seed_loopback, save_each,
skip_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoising_strength, resize_mode, False)
normalize_prompt_weights, use_GFPGAN, write_info_files, write_sample_info_to_log_file, prompt_matrix, init_img, uses_loopback, uses_random_seed_loopback, skip_save,
skip_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoising_strength, resize_mode, False)
if add_original_image or not simple_templating:
output_images.append(image)
if simple_templating:
grid_captions.append( captions[i] )
# Save the progress images?
if job_info:
if job_info.rec_steps_enabled and (job_info.rec_steps_to_file or job_info.rec_steps_to_gallery):
steps_grid = image_grid(job_info.rec_steps_imgs, 1)
if job_info.rec_steps_to_gallery:
gallery_img_size = tuple(2*dim for dim in image.size)
output_images.append( steps_grid.resize( gallery_img_size ) )
if job_info.rec_steps_to_file:
steps_grid_filename = f"{original_filename}_step_grid"
save_sample(steps_grid, sample_path_i, steps_grid_filename, jpg_sample, prompts, seeds, width, height, steps, cfg_scale,
normalize_prompt_weights, use_GFPGAN, write_info_files, write_sample_info_to_log_file, prompt_matrix, init_img, uses_loopback, uses_random_seed_loopback, skip_save,
skip_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoising_strength, resize_mode, False)
if opt.optimized:
mem = torch.cuda.memory_allocated()/1e6
modelFS.to("cpu")
@ -1263,7 +1094,7 @@ def txt2img(prompt: str, ddim_steps: int, sampler_name: str, toggles: List[int],
seed = seed_to_int(seed)
prompt_matrix = 0 in toggles
normalize_prompt_weights = 1 in toggles
save_each = 2 in toggles
skip_save = 2 not in toggles
skip_grid = 3 not in toggles
sort_samples = 4 in toggles
write_info_files = 5 in toggles
@ -1302,8 +1133,8 @@ def txt2img(prompt: str, ddim_steps: int, sampler_name: str, toggles: List[int],
def init():
pass
def sample(init_data, x, conditioning, unconditional_conditioning, sampler_name, img_callback: Callable = None):
samples_ddim, _ = sampler.sample(S=ddim_steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=cfg_scale, unconditional_conditioning=unconditional_conditioning, eta=ddim_eta, x_T=x, img_callback=img_callback)
def sample(init_data, x, conditioning, unconditional_conditioning, sampler_name):
samples_ddim, _ = sampler.sample(S=ddim_steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=cfg_scale, unconditional_conditioning=unconditional_conditioning, eta=ddim_eta, x_T=x)
return samples_ddim
try:
@ -1314,7 +1145,7 @@ def txt2img(prompt: str, ddim_steps: int, sampler_name: str, toggles: List[int],
prompt=prompt,
seed=seed,
sampler_name=sampler_name,
save_each=save_each,
skip_save=skip_save,
skip_grid=skip_grid,
batch_size=batch_size,
n_iter=n_iter,
@ -1393,9 +1224,14 @@ class Flagging(gr.FlaggingCallback):
print("Logged:", filenames[0])
def img2img(prompt: str, image_editor_mode: str, init_info: any, init_info_mask: any, mask_mode: str, mask_blur_strength: int, ddim_steps: int, sampler_name: str,
def img2img(prompt: str, image_editor_mode: str, mask_mode: str, mask_blur_strength: int, ddim_steps: int, sampler_name: str,
toggles: List[int], realesrgan_model_name: str, n_iter: int, cfg_scale: float, denoising_strength: float,
seed: int, height: int, width: int, resize_mode: int, fp = None, job_info: JobInfo = None):
seed: int, height: int, width: int, resize_mode: int, init_info: any = None, init_info_mask: any = None, fp = None, job_info: JobInfo = None):
print([prompt, image_editor_mode, init_info, init_info_mask, mask_mode,
mask_blur_strength, ddim_steps, sampler_name, toggles,
realesrgan_model_name, n_iter, cfg_scale,
denoising_strength, seed, height, width, resize_mode,
fp])
outpath = opt.outdir_img2img or opt.outdir or "outputs/img2img-samples"
err = False
seed = seed_to_int(seed)
@ -1406,7 +1242,7 @@ def img2img(prompt: str, image_editor_mode: str, init_info: any, init_info_mask:
normalize_prompt_weights = 1 in toggles
loopback = 2 in toggles
random_seed_loopback = 3 in toggles
save_each = 4 in toggles
skip_save = 4 not in toggles
skip_grid = 5 not in toggles
sort_samples = 6 in toggles
write_info_files = 7 in toggles
@ -1441,44 +1277,35 @@ def img2img(prompt: str, image_editor_mode: str, init_info: any, init_info_mask:
raise Exception("Unknown sampler: " + sampler_name)
if image_editor_mode == 'Mask':
global returned_info
init_img = init_info_mask["image"]
init_img = init_img.convert("RGB")
init_img = resize_image(resize_mode, init_img, width, height)
image = image.convert("RGB")
init_img = init_img.convert("RGB")
init_mask = init_info_mask["mask"]
init_mask = resize_image(resize_mode, init_mask, width, height)
resize_mask = mask_mode == 2
if resize_mask:
returned_info = crop_image(init_img, init_mask, width, height)
init_img = returned_info["cropped_img"]
init_mask = returned_info["cropped_mask"]
keep_mask = mask_mode == 0
init_mask = init_mask.convert("RGB")
init_mask = resize_image(resize_mode, init_mask, width, height)
init_mask = init_mask.convert("RGB")
keep_mask = mask_mode == 0
init_mask = init_mask if keep_mask else ImageOps.invert(init_mask)
else:
init_img = init_info.convert("RGB")
init_img = init_info
init_mask = None
keep_mask = False
resize_mask = False
assert 0. <= denoising_strength <= 1., 'can only work with strength in [0.0, 1.0]'
t_enc = int(denoising_strength * ddim_steps)
def init():
image = init_img.convert("RGB")
if resize_mask:
image = resize_image(resize_mode, image, width, height)
#image = image.convert("RGB") #todo: mask mode -> ValueError: could not convert string to float:
image = resize_image(resize_mode, image, width, height)
#image = image.convert("RGB")
image = np.array(image).astype(np.float32) / 255.0
image = image[None].transpose(0, 3, 1, 2)
image = torch.from_numpy(image)
mask_channel = None
if image_editor_mode == "Uncrop":
alpha = init_img.convert("RGB")
alpha = init_img.convert("RGBA")
alpha = resize_image(resize_mode, alpha, width // 8, height // 8)
mask_channel = alpha.split()[-1]
mask_channel = mask_channel.filter(ImageFilter.GaussianBlur(4))
@ -1486,7 +1313,7 @@ def img2img(prompt: str, image_editor_mode: str, init_info: any, init_info_mask:
mask_channel[mask_channel >= 255] = 255
mask_channel[mask_channel < 255] = 0
mask_channel = Image.fromarray(mask_channel).filter(ImageFilter.GaussianBlur(2))
elif init_mask is not None:
elif image_editor_mode == "Mask":
alpha = init_mask.convert("RGBA")
alpha = resize_image(resize_mode, alpha, width // 8, height // 8)
mask_channel = alpha.split()[1]
@ -1514,7 +1341,7 @@ def img2img(prompt: str, image_editor_mode: str, init_info: any, init_info_mask:
return init_latent, mask,
def sample(init_data, x, conditioning, unconditional_conditioning, sampler_name, img_callback: Callable = None):
def sample(init_data, x, conditioning, unconditional_conditioning, sampler_name):
t_enc_steps = t_enc
obliterate = False
if ddim_steps == t_enc_steps:
@ -1536,7 +1363,7 @@ def img2img(prompt: str, image_editor_mode: str, init_info: any, init_info_mask:
sigma_sched = sigmas[ddim_steps - t_enc_steps - 1:]
model_wrap_cfg = CFGMaskedDenoiser(sampler.model_wrap)
samples_ddim = K.sampling.__dict__[f'sample_{sampler.get_sampler_name()}'](model_wrap_cfg, xi, sigma_sched, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': cfg_scale, 'mask': z_mask, 'x0': x0, 'xi': xi}, disable=False, callback=partial(KDiffusionSampler.img_callback_wrapper, img_callback))
samples_ddim = K.sampling.__dict__[f'sample_{sampler.get_sampler_name()}'](model_wrap_cfg, xi, sigma_sched, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': cfg_scale, 'mask': z_mask, 'x0': x0, 'xi': xi}, disable=False)
else:
x0, z_mask = init_data
@ -1563,7 +1390,17 @@ def img2img(prompt: str, image_editor_mode: str, init_info: any, init_info_mask:
history = []
initial_seed = None
do_color_correction = False
try:
from skimage import exposure
do_color_correction = True
except:
print("Install scikit-image to perform color correction on loopback")
for i in range(n_iter):
if do_color_correction and i == 0:
correction_target = cv2.cvtColor(np.asarray(init_img.copy()), cv2.COLOR_RGB2LAB)
output_images, seed, info, stats = process_images(
outpath=outpath,
func_init=init,
@ -1571,7 +1408,7 @@ def img2img(prompt: str, image_editor_mode: str, init_info: any, init_info_mask:
prompt=prompt,
seed=seed,
sampler_name=sampler_name,
save_each=save_each,
skip_save=skip_save,
skip_grid=skip_grid,
batch_size=1,
n_iter=1,
@ -1605,6 +1442,17 @@ def img2img(prompt: str, image_editor_mode: str, init_info: any, init_info_mask:
initial_seed = seed
init_img = output_images[0]
if do_color_correction and correction_target is not None:
init_img = Image.fromarray(cv2.cvtColor(exposure.match_histograms(
cv2.cvtColor(
np.asarray(init_img),
cv2.COLOR_RGB2LAB
),
correction_target,
channel_axis=2
), cv2.COLOR_LAB2RGB).astype("uint8"))
if not random_seed_loopback:
seed = seed + 1
else:
@ -1630,7 +1478,7 @@ def img2img(prompt: str, image_editor_mode: str, init_info: any, init_info_mask:
prompt=prompt,
seed=seed,
sampler_name=sampler_name,
save_each=save_each,
skip_save=skip_save,
skip_grid=skip_grid,
batch_size=batch_size,
n_iter=n_iter,
@ -1655,7 +1503,6 @@ def img2img(prompt: str, image_editor_mode: str, init_info: any, init_info_mask:
write_info_files=write_info_files,
write_sample_info_to_log_file=write_sample_info_to_log_file,
jpg_sample=jpg_sample,
resize_mask=resize_mask,
job_info=job_info
)
@ -1723,10 +1570,9 @@ def imgproc(image,image_batch,imgproc_prompt,imgproc_toggles, imgproc_upscale_to
output = []
images = []
def processGFPGAN(image,strength):
cvimage = toImgOpenCV(image)
cropped_faces, restored_faces, restored_img = GFPGAN.enhance(np.array(cvimage, dtype=np.uint8), has_aligned=False, only_center_face=False, paste_back=True)
#save restored image
result = toImgPIL(restored_img)
image = image.convert("RGB")
cropped_faces, restored_faces, restored_img = GFPGAN.enhance(np.array(image, dtype=np.uint8), has_aligned=False, only_center_face=False, paste_back=True)
result = Image.fromarray(restored_img)
if strength < 1.0:
result = Image.blend(image, result, strength)
@ -1764,7 +1610,7 @@ def imgproc(image,image_batch,imgproc_prompt,imgproc_toggles, imgproc_upscale_to
height = int(imgproc_height)
cfg_scale = float(imgproc_cfg)
denoising_strength = float(imgproc_denoising)
save_each = True
skip_save = True
skip_grid = True
prompt = imgproc_prompt
t_enc = int(denoising_strength * ddim_steps)
@ -1918,7 +1764,7 @@ def imgproc(image,image_batch,imgproc_prompt,imgproc_toggles, imgproc_upscale_to
prompt=prompt,
seed=seed,
sampler_name=sampler_name,
save_each=save_each,
skip_save=skip_save,
skip_grid=skip_grid,
batch_size=batch_size,
n_iter=n_iter,
@ -1967,7 +1813,6 @@ def imgproc(image,image_batch,imgproc_prompt,imgproc_toggles, imgproc_upscale_to
return result
if image_batch != None:
if image != None:
print("Batch detected and single image detected, please only use one of the two. Aborting.")
@ -2106,14 +1951,15 @@ def ModelLoader(models,load=False,unload=False,imgproc_realesrgan_model_name='Re
def run_GFPGAN(image, strength):
ModelLoader(['LDSR','RealESRGAN'],False,True)
ModelLoader(['GFPGAN'],True,False)
cvimage = toImgOpenCV(image)
cropped_faces, restored_faces, restored_img = GFPGAN.enhance(np.array(cvimage, dtype=np.uint8), has_aligned=False, only_center_face=False, paste_back=True)
#save restored image
result = toImgPIL(restored_img)
if strength < 1.0:
result = Image.blend(image, result, strength)
image = image.convert("RGB")
return result
cropped_faces, restored_faces, restored_img = GFPGAN.enhance(np.array(image, dtype=np.uint8), has_aligned=False, only_center_face=False, paste_back=True)
res = Image.fromarray(restored_img)
if strength < 1.0:
res = Image.blend(image, res, strength)
return res
def run_RealESRGAN(image, model_name: str):
ModelLoader(['GFPGAN','LDSR'],False,True)
@ -2195,9 +2041,9 @@ imgproc_mode_toggles = [
'Upscale'
]
sample_img2img = "assets/stable-samples/img2img/sketch-mountains-input.jpg"
sample_img2img = sample_img2img if os.path.exists(sample_img2img) else None
#sample_img2img = "assets/stable-samples/img2img/sketch-mountains-input.jpg"
#sample_img2img = sample_img2img if os.path.exists(sample_img2img) else None
sample_img2img = None
# make sure these indicies line up at the top of img2img()
img2img_toggles = [
'Create prompt matrix (separate multiple prompts using |, and get all combinations of them)',
@ -2226,7 +2072,6 @@ img2img_resize_modes = [
"Just resize",
"Crop and resize",
"Resize and fill",
"Resize Masked Area"
]
img2img_defaults = {
@ -2262,22 +2107,13 @@ def update_image_mask(cropped_image, resize_mode, width, height):
resized_cropped_image = resize_image(resize_mode, cropped_image, width, height) if cropped_image else None
return gr.update(value=resized_cropped_image)
def copy_img_to_input(img):
try:
image_data = re.sub('^data:image/.+;base64,', '', img)
processed_image = Image.open(BytesIO(base64.b64decode(image_data)))
tab_update = gr.update(selected='img2img_tab')
img_update = gr.update(value=processed_image)
return {img2img_image_mask: processed_image, img2img_image_editor: img_update, tabs: tab_update}
except IndexError:
return [None, None]
def copy_img_to_upscale_esrgan(img):
update = gr.update(selected='realesrgan_tab')
image_data = re.sub('^data:image/.+;base64,', '', img)
processed_image = Image.open(BytesIO(base64.b64decode(image_data)))
return {realesrgan_source: processed_image, tabs: update}
return {'realesrgan_source': processed_image, 'tabs': update}
help_text = """